diff --git a/core/src/main/scala-2/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala b/core/src/main/scala-2/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala index 3aa1bb60..e734e026 100644 --- a/core/src/main/scala-2/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala +++ b/core/src/main/scala-2/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala @@ -30,10 +30,16 @@ trait DiffMagnoliaDerivation { val lType = ctx.split(left)(a => a) val rType = ctx.split(right)(a => a) if (lType == rType) { - lType.typeclass( + val leftTypeClass = lType.typeclass + val contextPath = ModifyPath.Subtype(lType.typeName.owner, lType.typeName.short) + val modifyFromOverride = context + .getOverride(contextPath) + .map(_.asInstanceOf[leftTypeClass.type => leftTypeClass.type]) + .getOrElse(identity[leftTypeClass.type] _) + modifyFromOverride(leftTypeClass)( lType.cast(left), lType.cast(right), - context.getNextStep(ModifyPath.Subtype(lType.typeName.owner, lType.typeName.short)).merge(context) + context.getNextStep(contextPath).merge(context) ) } else { DiffResultValue(lType.typeName.full, rType.typeName.full) diff --git a/core/src/main/scala-3/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala b/core/src/main/scala-3/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala index 40cfe487..8c309a27 100644 --- a/core/src/main/scala-3/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala +++ b/core/src/main/scala-3/com/softwaremill/diffx/generic/DiffMagnoliaDerivation.scala @@ -50,10 +50,16 @@ trait DiffMagnoliaDerivation extends Derivation[Diff] { val lType = ctx.choose(left)(a => a) val rType = ctx.choose(right)(a => a) if (lType.typeInfo == rType.typeInfo) { - lType.typeclass( + val leftTypeClass = lType.typeclass + val contextPath = ModifyPath.Subtype(lType.typeInfo.owner, lType.typeInfo.short) + val modifyFromOverride = context + .getOverride(contextPath) + .map(_.asInstanceOf[leftTypeClass.type => leftTypeClass.type]) + .getOrElse(identity[leftTypeClass.type] _) + modifyFromOverride(leftTypeClass)( lType.cast(left), lType.cast(right), - context.getNextStep(ModifyPath.Subtype(lType.typeInfo.owner, lType.typeInfo.short)).merge(context) + context.getNextStep(contextPath).merge(context) ) } else { DiffResultValue(lType.typeInfo.full, rType.typeInfo.full) diff --git a/core/src/main/scala/com/softwaremill/diffx/DiffContext.scala b/core/src/main/scala/com/softwaremill/diffx/DiffContext.scala index b60af110..7adf2fca 100644 --- a/core/src/main/scala/com/softwaremill/diffx/DiffContext.scala +++ b/core/src/main/scala/com/softwaremill/diffx/DiffContext.scala @@ -22,7 +22,7 @@ case class DiffContext( private def treeOverride[T](nextPath: ModifyPath, tree: Tree[T]) = { tree match { - case Tree.Leaf(v) => Some(v) + case Tree.Leaf(v) => None case Tree.Node(tries) => getOverrideFromNode(nextPath, tries) } } diff --git a/core/src/test/scala/com/softwaremill/diffx/test/DiffModifyIntegrationTest.scala b/core/src/test/scala/com/softwaremill/diffx/test/DiffModifyIntegrationTest.scala index cb21f8d9..426f43b3 100644 --- a/core/src/test/scala/com/softwaremill/diffx/test/DiffModifyIntegrationTest.scala +++ b/core/src/test/scala/com/softwaremill/diffx/test/DiffModifyIntegrationTest.scala @@ -7,6 +7,7 @@ import org.scalatest.matchers.should.Matchers import java.time.Instant import java.util.UUID +import scala.collection.immutable.ListMap class DiffModifyIntegrationTest extends AnyFlatSpec with Matchers with AutoDerivation { val instant: Instant = Instant.now() @@ -293,4 +294,30 @@ class DiffModifyIntegrationTest extends AnyFlatSpec with Matchers with AutoDeriv val d2 = Diff[Family].modify(_.second.name).ignore.modify(_.first).ignore compare(f1, f2)(d2).isIdentical shouldBe true } + + it should "allow to set custom diff to a nested case class field" in { + case class Address(house: Int, street: String) + case class Person(name: String, address: Address) + + val add = Diff.summon[Address] + val d = Diff + .summon[Person] + .modify(_.address) + .setTo(add) + + val a1 = Address(123, "Robin St.") + val a2 = Address(456, "Robin St.") + val p1 = Person("Mason", a1) + val p2 = Person("Mason", a2) + d(p1, p2) shouldBe DiffResultObject( + "Person", + ListMap( + "name" -> IdenticalValue("Mason"), + "address" -> DiffResultObject( + "Address", + ListMap("house" -> DiffResultValue(123, 456), "street" -> IdenticalValue("Robin St.")) + ) + ) + ) + } }