Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extension copy #262

Merged
merged 14 commits into from
Dec 17, 2024
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ excludeLintKeys in Global ++= Set(ideSkipProject)
val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
organization := "com.softwaremill.quicklens",
updateDocs := UpdateVersionInDocs(sLog.value, organization.value, version.value, List(file("README.md"))),
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all"
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all", "-Xcheck-macros"
ideSkipProject := (scalaVersion.value != scalaIdeaVersion)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ object QuicklensMacros {
def noSuchMember(tpeStr: String, name: String) =
s"$tpeStr has no member named $name"

def noSuitableMember(tpeStr: String, name: String, argNames: Iterable[String]) =
s"$tpeStr has no member $name with parameters ${argNames.mkString("(", ", ", ")")}"

def multipleMatchingMethods(tpeStr: String, name: String, syms: Seq[Symbol]) =
val symsStr = syms.map(s => s" - $s: ${s.termRef.dealias.widen.show}").mkString("\n", "\n", "")
s"Multiple methods named $name found in $tpeStr: $symsStr"
Expand Down Expand Up @@ -109,11 +112,14 @@ object QuicklensMacros {
case (symbol :: tail) => PathTree.Node(Seq(symbol -> Seq(tail.toPathTree)))

enum PathSymbol:
case Field(name: String)
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
case Field(override val name: String)
case Extension(term: Term, override val name: String)
case FunctionDelegate(override val name: String, givn: Term, typeTree: TypeTree, args: List[Term])
def name: String

def equiv(other: Any): Boolean = (this, other) match
case (Field(name1), Field(name2)) => name1 == name2
case (Extension(term1, name1), Extension(term2, name2)) => term1 == term2 && name1 == name2
case (FunctionDelegate(name1, _, typeTree1, args1), FunctionDelegate(name2, _, typeTree2, args2)) =>
name1 == name2 && typeTree1.tpe == typeTree2.tpe && args1 == args2
case _ => false
Expand All @@ -133,6 +139,9 @@ object QuicklensMacros {
/** Method call with one type parameter and using clause */
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
/** Extension method, which is called e.g. as x(_$1) */
case Apply(obj@Select(term, member), Seq(deep)) if obj.symbol.flags.is(Flags.ExtensionMethod) =>
toPath(deep, focus) :+ PathSymbol.Extension(term, member)
/** Field access */
case Apply(deep, idents) =>
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
Expand All @@ -157,43 +166,104 @@ object QuicklensMacros {
def matchingTypeSymbol: Symbol = tpe.widenAll match {
case AndType(l, r) =>
val lSym = l.matchingTypeSymbol
if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
tpe.typeSymbol
case tpe if isProductLike(tpe.typeSymbol) =>
if lSym != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) || isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
OndrejSpanel marked this conversation as resolved.
Show resolved Hide resolved
Symbol.noSymbol
}

def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = {
val mem = sym.fieldMember(name)
if mem != Symbol.noSymbol then mem
else methodSymbolByNameOrError(sym, name)
extension (term: Term)
def appliedToIfNeeded(args: List[Term]): Term =
if args.isEmpty then term else term.appliedToArgs(args)

def symbolAccessorByNameOrError(obj: Term, name: String): Term = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol
// opaque types can find members of underlying types - ignore them (see https://github.com/scala/scala3/issues/22143)
val fieldMemberSym = objSymbol.fieldMember(name)
if !objSymbol.flags.is(Flags.Deferred) && fieldMemberSym.exists then
Select(obj, fieldMemberSym)
else
objSymbol.methodMember(name) match
case List(m) =>
Select(obj, m)
case lst =>
report.errorAndAbort(reportMethodError(objSymbol, name, lst))
}

def reportMethodError(sym: Symbol, name: String, lst: List[Symbol], maybeArgNames: Option[Iterable[String]] = None): String = {
(lst, maybeArgNames) match
case (Nil, _) => noSuchMember(sym.name, name)
case (lst, None) => multipleMatchingMethods(sym.name, name, lst)
case (lst, Some(argNames)) => noSuitableMember(sym.name, name, argNames)
}

def methodSymbolByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
case lst => report.errorAndAbort(reportMethodError(sym, name, lst))
}

def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = {
def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = {
val argNames = argsMap.keys
sym.methodMember(name).filter{ msym =>
allMethods.filter { msym =>
// for copy, we filter out the methods that don't have the desired parameter names
val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name)
argNames.forall(paramNames.contains)
} match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case lst @ (m :: _) =>
case List(m) => Some(m)
case Nil => None
case lst@(m :: _) =>
// if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method
val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic))
syntheticCopies match
case List(mSynth) => mSynth
case _ => m
case List(mSynth) => Some(mSynth)
case _ => Some(m)
}

def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Either[String, Symbol] = {
if !sym.flags.is(Flags.Deferred) then
val memberMethods = sym.methodMember(name)
filterMethodsByNameAndArgs(memberMethods, argsMap)
.toRight(reportMethodError(sym, name, memberMethods, Some(argsMap.keys)))
else Left(s"Deferred type ${sym.name}")
}

/**
* @param argsMap normal methods receive one parameter list, extensions methods two, the first one contains the value
* on which the extension is called
* */
def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]]) = {
require(argsMap.size == 1 || argsMap.size == 2, s"argsMap.size should be either 1 or 2, got: ${argsMap.size} ($argsMap)")
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol

val typeParams = objTpe.typeArgs
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap)
.map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
.flatten.toList

val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
val i = _i + 1
def defaultMethod: Term =
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
// default values in extension methods take the extension receiver as the first parameter
val defaultMethodArgs = argsMap.dropRight(1).flatMap(_.values)
obj.select(methodSymbol).appliedToIfNeeded(defaultMethodArgs)
n -> v.getOrElse(defaultMethod)
}.toMap

val argLists: List[List[Term]] = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))

if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}"
)

val withTypeParamsApplied = obj.select(copy).appliedToTypes(typeParams)
argLists.foldLeft(withTypeParamsApplied)(Apply(_, _))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
Expand All @@ -210,15 +280,32 @@ object QuicklensMacros {
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def findCompanionLikeObject(objSymbol: Symbol): Symbol = {
if objSymbol.companionModule.exists then
objSymbol.companionModule
else
val namedFromOwnerScope = objSymbol.owner.fieldMember(objSymbol.name)
if namedFromOwnerScope.flags.is(Flags.Module) then namedFromOwnerScope
else Symbol.noSymbol
}

def hasExtensionNamed(sym: Symbol, methodName: String): List[Symbol] = {
val companionSymbol = findCompanionLikeObject(sym)
if companionSymbol.exists then
companionSymbol.methodMember(methodName).filter(s => s.name == methodName && s.flags.is(Flags.ExtensionMethod))
else
Nil
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").size >= 1
sym.methodMember("copy").nonEmpty || hasExtensionNamed(sym, "copy").nonEmpty
}

def caseClassCopy(
owner: Symbol,
mod: Expr[A => A],
obj: Term,
fields: Seq[(PathSymbol.Field, Seq[PathTree])]
fields: Seq[(PathSymbol.Field | PathSymbol.Extension, Seq[PathTree])]
): Term = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol
Expand Down Expand Up @@ -248,50 +335,39 @@ object QuicklensMacros {
}

val elseThrow = '{ throw new IllegalStateException() }.asTerm

ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val fieldMethod = symbolAccessorByNameOrError(objSymbol, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
val fieldMethod = field match {
case PathSymbol.Field(name) =>
symbolAccessorByNameOrError(obj, name)
case PathSymbol.Extension(term, name) =>
val extensionMethod = symbolAccessorByNameOrError(term, name)
Apply(extensionMethod, List(obj))
}
val resTerm: Term = trees.foldLeft[Term](fieldMethod) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap
val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap)

val typeParams = objTpe match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)

val args = copyParamNames.zipWithIndex.map { (n, _i) =>
val i = _i + 1
val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString))
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
argsMap.getOrElse(
n,
defaultMethod
)
}.toList

if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
case Right(copy) =>
callMethod(obj, copy, List(argsMap))
case Left(error) =>
val objCompanion = findCompanionLikeObject(objSymbol)
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap).toOption match
OndrejSpanel marked this conversation as resolved.
Show resolved Hide resolved
case Some(copy) =>
// now try to call the extension as a method, assume the object is its first parameter
val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap)
callMethod(Ref(objCompanion), copy, argsWithObj)
case None => report.errorAndAbort(error)
} else
report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
report.errorAndAbort(s"Unsupported source object: must be a case class, sealed trait or class with copy method, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
}

def applyFunctionDelegate(
Expand Down Expand Up @@ -331,9 +407,9 @@ object QuicklensMacros {
case Nil =>
objTerm

case (_: PathSymbol.Field, _) :: _ =>
val (fs, funs) = pathSymbols.span(_._1.isInstanceOf[PathSymbol.Field])
val fields = fs.collect { case (p: PathSymbol.Field, trees) => p -> trees }
case (_: (PathSymbol.Field | PathSymbol.Extension), _) :: _ =>
val (fs, funs) = pathSymbols.span((ps, _) => ps.isInstanceOf[PathSymbol.Field] || ps.isInstanceOf[PathSymbol.Extension])
val fields = fs.collect { case (p: (PathSymbol.Field | PathSymbol.Extension), trees) => p -> trees }
val withCopiedFields: Term = caseClassCopy(owner, mod, objTerm, fields)
accumulateToCopy(owner, mod, withCopiedFields, funs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ package object quicklens {
def map[A](fa: M[A], f: A => A): M[A] = {
val mapped = fa.view.mapValues(f)
(fa match {
case sfa: SortedMap[K, A] => sfa.sortedMapFactory.from(mapped)(using sfa.ordering)
case sfa: SortedMap[K, A]@unchecked => sfa.sortedMapFactory.from(mapped)(using sfa.ordering)
case _ => mapped.to(fa.mapFactory)
}).asInstanceOf[M[A]]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package com.softwaremill.quicklens
package test

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -33,7 +34,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
def paths(paths: Paths): Docs = copy(paths = paths)
}
val docs = Docs()
docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem()))
val r = docs.modify(_.paths.pathItems).using(m => m + ("a" -> PathItem()))
r.paths.pathItems should contain ("a" -> PathItem())
}

it should "modify a case class with an additional explicit copy" in {
Expand All @@ -42,7 +44,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
}

val f = Frozen("A", 0)
f.modify(_.state).setTo("B")
val r = f.modify(_.state).setTo("B")
r.state shouldEqual "B"
}

it should "modify a case class with an ambiguous additional explicit copy" in {
Expand All @@ -51,7 +54,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
}

val f = Frozen("A", 0)
f.modify(_.state).setTo("B")
val r = f.modify(_.state).setTo("B")
r.state shouldEqual "B"
}

it should "modify a class with two explicit copy methods" in {
Expand All @@ -61,7 +65,8 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
}

val f = new Frozen("A", 0)
f.modify(_.state).setTo("B")
val r = f.modify(_.state).setTo("B")
r.state shouldEqual "B"
}

it should "modify a case class with an ambiguous additional explicit copy and pick the synthetic one first" in {
Expand All @@ -77,6 +82,19 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
accessed shouldEqual 0
}

it should "not compile when modifying a field which is not present as a copy parameter" in {
"""
case class Content(x: String)

class A(val c: Content) {
def copy(x: String = c.x): A = new A(Content(x))
}

val a = new A(Content("A"))
val am = a.modify(_.c).setTo(Content("B"))
""" shouldNot compile
}

// TODO: Would be nice to be able to handle this case. Based on the types, it
// is obvious, that the explicit copy should be picked, but I'm not sure if we
// can get that information
Expand All @@ -90,5 +108,4 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
// val f = Frozen("A", 0)
// f.modify(_.state).setTo('B')
// }

}
Loading