Skip to content

Commit

Permalink
Revive "Strict Equality" for assertEquals() (#521)
Browse files Browse the repository at this point in the history
* Make FailException and ComparisonFailException work more similarly.

Previously, FailException had some custom nice-to-have features that
ComparisonFailException didn't have.

* Introduce "strict equality" mode for `assertEquals()` and friends.

Previously, MUnit had a subtyping constraint on `assertEquals(a, b)` so
that it would fail to compile if `a` was not a subtype of `b`. This was
a suboptimal solution because the compile error messages could become
cryptic in some cases. Additionally, this API didn't integrate with
other libaries like Cats that has its own `cats.Eq[A,B]` type-class.

Now, MUnit uses a new `munit.Compare[A,B]` type-class for comparing
values of different types. By default, MUnit provides a "universal"
instance that permits comparison between all types and uses the built-in
`==` method. Users can optionally enable "strict equality" by adding the
compiler option `"-Xmacro-settings.munit.strictEquality"` in Scala 2.
In Scala 3, we use the `Eql[A, B]` type-classes instead to determine
type equality.

* Address review feedback

* Drop strict equality, allow comparison between supertypes/subtypes

This is a fourth attempt at improving strict equality in MUnit
`assertEquals()` assertions.

* First attempt (current release version): require second argument to be
  a supertype of the first argument. This has the flaw that the compile
  error message is cryptic and that the ordering of the arguments affects
  compilation.
* Second attempt: use `Eql[A, B]` in Scala 3 and allow comparing any
  types in Scala 2. This has the flaw that it's a regression in some
  cases for Scala 2 users and that `Eql[A, B]` is not really usable
  in its current form, see related discussion
  https://contributors.scala-lang.org/t/should-multiversal-equality-provide-default-eql-instances/4574
* Third attempt: implement "strict equality" for Scala 2 with a macro
  and `Eql[T, T]` in Scala. This improves the situation for Scala 2,
  but would mean relying on a feature that we can't easily port to Scala 3.
* Fourth attempt (this commit): improve the first attempt (current
  release) by allowing `Compare[A, B]` as long
  as `A <:< B` OR `B <:< A`. This is possible thanks to an observation
  by Gabriele Petronella that it's possible to layer the implicits to
  avoid diverging implicit search.

The benefit of the fourth approach is that it works the same way for
Scala 3 and Scala 3. It's very nice that we can avoid macros as well.

* Address review feedback

* Run scalafmtSbt

* Remove unused import

* Fix dotty tests in AssertionsSuite

The Scala 3 (dotty) tests now use compareSubtypeWithSupertype instead
of compareSupertypeWithSubtype. Additionally, the "unrelated" test was
not seeing the context code above and so I've moved all the code into
compileErrors.

* Add mima exclusions for assertEquals and co

* Remove unused import in scala-3 MacroCompat

* Reintroduce special-case msgs for comparing arrays

* Reintroduce better string inequality error msgs

* Update Clue deprecation to 1.0

Co-authored-by: Ólafur Páll Geirsson <[email protected]>

* Fix typo in AssertionsSuite test name

Co-authored-by: Ólafur Páll Geirsson <[email protected]>

Co-authored-by: Olafur Pall Geirsson <[email protected]>
Co-authored-by: Ólafur Páll Geirsson <[email protected]>
  • Loading branch information
3 people authored Apr 16, 2022
1 parent 675a25e commit 8558ab1
Show file tree
Hide file tree
Showing 18 changed files with 483 additions and 120 deletions.
55 changes: 39 additions & 16 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ lazy val mimaEnable: List[Def.Setting[_]] = List(
"munit.internal.junitinterface.JUnitComputer.this"
),
// Known breaking changes for MUnit v1
ProblemFilters.exclude[DirectMissingMethodProblem](
"munit.Assertions.assertNotEquals"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"munit.Assertions.assertEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.Assertions.assertNotEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.Assertions.assertEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.FunSuite.assertNotEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.FunSuite.assertEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.FunSuite.munitTestTransform"
),
Expand Down Expand Up @@ -194,22 +212,8 @@ lazy val junit = project
lazy val munit = crossProject(JSPlatform, JVMPlatform, NativePlatform)
.settings(
sharedSettings,
Compile / unmanagedSourceDirectories ++= {
val root = (ThisBuild / baseDirectory).value / "munit"
val base = root / "shared" / "src" / "main"
val result = mutable.ListBuffer.empty[File]
val partialVersion = CrossVersion.partialVersion(scalaVersion.value)
if (isPreScala213(partialVersion)) {
result += base / "scala-pre-2.13"
}
if (isNotScala211(partialVersion)) {
result += base / "scala-post-2.11"
}
if (isScala2(partialVersion)) {
result += base / "scala-2"
}
result.toList
},
Compile / unmanagedSourceDirectories ++=
crossBuildingDirectories("munit", "main").value,
libraryDependencies ++= List(
"org.scala-lang" % "scala-reflect" % {
if (isScala3Setting.value) scala213
Expand Down Expand Up @@ -308,6 +312,8 @@ lazy val tests = crossProject(JSPlatform, JVMPlatform, NativePlatform)
((ThisBuild / baseDirectory).value / "tests" / "shared" / "src" / "main").getAbsolutePath.toString,
scalaVersion
),
Test / unmanagedSourceDirectories ++=
crossBuildingDirectories("tests", "test").value,
publish / skip := true
)
.nativeConfigure(sharedNativeConfigure)
Expand Down Expand Up @@ -348,3 +354,20 @@ lazy val docs = project
Global / excludeLintKeys ++= Set(
mimaPreviousArtifacts
)
def crossBuildingDirectories(name: String, config: String) =
Def.setting[Seq[File]] {
val root = (ThisBuild / baseDirectory).value / name
val base = root / "shared" / "src" / config
val result = mutable.ListBuffer.empty[File]
val partialVersion = CrossVersion.partialVersion(scalaVersion.value)
if (isPreScala213(partialVersion)) {
result += base / "scala-pre-2.13"
}
if (isNotScala211(partialVersion)) {
result += base / "scala-post-2.11"
}
if (isScala2(partialVersion)) {
result += base / "scala-2"
}
result.toList
}
24 changes: 15 additions & 9 deletions docs/assertions.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,29 @@ assertEquals(
Comparing two values of different types is a compile error.

```scala mdoc:fail
assertEquals(1, "")
assertEquals(Option("message"), "message")
```

The "expected" value (second argument) must be a subtype of the "obtained" value
(first argument).
It's a compile error even if the comparison is true at runtime.

```scala mdoc
assertEquals(Option(1), Some(1))
```scala mdoc:fail
assertEquals(List(1), Vector(1))
```

It's a compile error if you swap the order of the arguments.

```scala mdoc:fail
assertEquals(Some(1), Option(1))
assertEquals('a', 'a'.toInt)
```

It's OK to compare two types as long as one argument is a subtype of the other
type.

```scala mdoc
assertEquals(Option(1), Some(1)) // OK
assertEquals(Some(1), Option(1)) // OK
```

Use `assertEquals[Any, Any]` if you really want to compare two different types.
Use `assertEquals[Any, Any]` if you think it's OK to compare the two types at
runtime.

```scala mdoc
val right1: Either[String , Int] = Right(42)
Expand Down
79 changes: 15 additions & 64 deletions munit/shared/src/main/scala/munit/Assertions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,6 @@ trait Assertions extends MacroCompat.CompileErrorMacro {

def munitAnsiColors: Boolean = true

private def munitComparisonHandler(
actualObtained: Any,
actualExpected: Any
): ComparisonFailExceptionHandler =
new ComparisonFailExceptionHandler {
override def handle(
message: String,
unusedObtained: String,
unusedExpected: String,
loc: Location
): Nothing = failComparison(message, actualObtained, actualExpected)(loc)
}

private def munitFilterAnsi(message: String): String =
if (munitAnsiColors) message
else AnsiColors.filterAnsi(message)
Expand Down Expand Up @@ -67,20 +54,25 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
Diffs.assertNoDiff(
obtained,
expected,
munitComparisonHandler(obtained, expected),
ComparisonFailExceptionHandler.fromAssertions(this, Clues.empty),
munitPrint(clue),
printObtainedAsStripMargin = true
)
}
}

/**
* Asserts that two elements are not equal according to the `Compare[A, B]` type-class.
*
* By default, uses `==` to compare values.
*/
def assertNotEquals[A, B](
obtained: A,
expected: B,
clue: => Any = "values are the same"
)(implicit loc: Location, ev: A =:= B): Unit = {
)(implicit loc: Location, compare: Compare[A, B]): Unit = {
StackTraces.dropInside {
if (obtained == expected) {
if (compare.isEqual(obtained, expected)) {
failComparison(
s"${munitPrint(clue)} expected same: $expected was not: $obtained",
obtained,
Expand All @@ -91,32 +83,17 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
}

/**
* Asserts that two elements are equal using `==` equality.
*
* The "expected" value (second argument) must have the same type or be a
* subtype of the "obtained" value (first argument). For example:
* {{{
* assertEquals(Option(1), Some(1)) // OK
* assertEquals(Some(1), Option(1)) // Error: Option[Int] is not a subtype of Some[Int]
* }}}
* Asserts that two elements are equal according to the `Compare[A, B]` type-class.
*
* Use `assertEquals[Any, Any](a, b)` as an escape hatch to compare two
* values of different types. For example:
* {{{
* val a: Either[List[String], Int] = Right(42)
* val b: Either[String, Int] = Right(42)
* assertEquals[Any, Any](a, b) // OK
* assertEquals(a, b) // Error: Either[String, Int] is not a subtype of Either[List[String], Int]
* }}}
* By default, uses `==` to compare values.
*/
def assertEquals[A, B](
obtained: A,
expected: B,
clue: => Any = "values are not the same"
)(implicit loc: Location, ev: B <:< A): Unit = {
)(implicit loc: Location, compare: Compare[A, B]): Unit = {
StackTraces.dropInside {
if (obtained != expected) {

if (!compare.isEqual(obtained, expected)) {
(obtained, expected) match {
case (a: Array[_], b: Array[_]) if a.sameElements(b) =>
// Special-case error message when comparing arrays. See
Expand All @@ -137,34 +114,7 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
)
case _ =>
}

Diffs.assertNoDiff(
munitPrint(obtained),
munitPrint(expected),
munitComparisonHandler(obtained, expected),
munitPrint(clue),
printObtainedAsStripMargin = false
)
// try with `.toString` in case `munitPrint()` produces identical formatting for both values.
Diffs.assertNoDiff(
obtained.toString(),
expected.toString(),
munitComparisonHandler(obtained, expected),
munitPrint(clue),
printObtainedAsStripMargin = false
)
if (obtained.toString() == expected.toString())
failComparison(
s"values are not equal even if they have the same `toString()`: $obtained",
obtained,
expected
)
else
failComparison(
s"values are not equal, even if their text representation only differs in leading/trailing whitespace and ANSI escape characters: $obtained",
obtained,
expected
)
compare.failEqualsComparison(obtained, expected, clue, loc, this)
}
}
}
Expand Down Expand Up @@ -320,7 +270,8 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
munitFilterAnsi(munitLines.formatLine(loc, message, clues)),
obtained,
expected,
loc
loc,
isStackTracesEnabled = false
)
}

Expand Down
4 changes: 3 additions & 1 deletion munit/shared/src/main/scala/munit/Clue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ class Clue[+T](
override def toString(): String = s"Clue($source, $value)"
}
object Clue extends MacroCompat.ClueMacro {
def empty[T](value: T): Clue[T] = new Clue("", value, "")
@deprecated("use fromValue instead", "1.0.0")
def empty[T](value: T): Clue[T] = fromValue(value)
def fromValue[T](value: T): Clue[T] = new Clue("", value, "")
}
4 changes: 4 additions & 0 deletions munit/shared/src/main/scala/munit/Clues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ import munit.internal.console.Printers
class Clues(val values: List[Clue[_]]) {
override def toString(): String = Printers.print(this)
}
object Clues {
def empty: Clues = new Clues(List())
def fromValue[T](value: T): Clues = new Clues(List(Clue.fromValue(value)))
}
122 changes: 122 additions & 0 deletions munit/shared/src/main/scala/munit/Compare.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package munit

import munit.internal.difflib.Diffs
import munit.internal.difflib.ComparisonFailExceptionHandler
import scala.annotation.implicitNotFound

/**
* A type-class that is used to compare values in MUnit assertions.
*
* By default, uses == and allows comparison between any two types as long
* they have a supertype/subtype relationship. For example:
*
* - Compare[T, T] OK
* - Compare[Some[Int], Option[Int]] OK, subtype
* - Compare[Option[Int], Some[Int]] OK, supertype
* - Compare[List[Int], collection.Seq[Int]] OK, subtype
* - Compare[List[Int], Vector[Int]] Error, requires upcast to `Seq[Int]`
*/
@implicitNotFound(
// NOTE: Dotty ignores this message if the string is formatted as a multiline string """..."""
"Can't compare these two types:\n First type: ${A}\n Second type: ${B}\nPossible ways to fix this error:\n Alternative 1: provide an implicit instance for Compare[${A}, ${B}]\n Alternative 2: upcast either type into `Any` or a shared supertype"
)
trait Compare[A, B] {

/**
* Returns true if the values are equal according to the rules of this `Compare[A, B]` instance.
*
* The default implementation of this method uses `==`.
*/
def isEqual(obtained: A, expected: B): Boolean

/**
* Throws an exception to fail this assertion when two values are not equal.
*
* Override this method to customize the error message. For example, it may
* be helpful to generate an image/HTML file if you're comparing visual
* values. Anything is possible, use your imagination!
*
* @return should ideally throw a org.junit.ComparisonFailException in order
* to support the IntelliJ diff viewer.
*/
def failEqualsComparison(
obtained: A,
expected: B,
title: Any,
loc: Location,
assertions: Assertions
): Nothing = {
val diffHandler = new ComparisonFailExceptionHandler {
override def handle(
message: String,
_obtained: String,
_expected: String,
loc: Location
): Nothing =
assertions.failComparison(
message,
obtained,
expected
)(loc)
}
// Attempt 1: custom pretty-printer that produces multiline output, which is
// optimized for line-by-line diffing.
Diffs.assertNoDiff(
assertions.munitPrint(obtained),
assertions.munitPrint(expected),
diffHandler,
title = assertions.munitPrint(title),
printObtainedAsStripMargin = false
)(loc)

// Attempt 2: try with `.toString` in case `munitPrint()` produces identical
// formatting for both values.
Diffs.assertNoDiff(
obtained.toString(),
expected.toString(),
diffHandler,
title = assertions.munitPrint(title),
printObtainedAsStripMargin = false
)(loc)

// Attempt 3: string comparison is not working, unconditionally fail the test.
if (obtained.toString() == expected.toString())
assertions.failComparison(
s"values are not equal even if they have the same `toString()`: $obtained",
obtained,
expected
)(loc)
else
assertions.failComparison(
s"values are not equal, even if their text representation only differs in leading/trailing whitespace and ANSI escape characters: $obtained",
obtained,
expected
)(loc)
}

}

object Compare extends ComparePriority1 {
private val anyEquality: Compare[Any, Any] = _ == _
def defaultCompare[A, B]: Compare[A, B] =
anyEquality.asInstanceOf[Compare[A, B]]
}

/** Allows comparison between A and B when A is a subtype of B */
trait ComparePriority1 extends ComparePriority2 {
implicit def compareSubtypeWithSupertype[A, B](implicit
ev: A <:< B
): Compare[A, B] = Compare.defaultCompare
}

/**
* Allows comparison between A and B when B is a subtype of A.
*
* This implicit is defined separately from ComparePriority1 in order to avoid
* diverging implicit search when comparing equal types.
*/
trait ComparePriority2 {
implicit def compareSupertypeWithSubtype[A, B](implicit
ev: A <:< B
): Compare[B, A] = Compare.defaultCompare
}
Loading

0 comments on commit 8558ab1

Please sign in to comment.