From 4190e7615af8137324da16326a7328f574e1d38e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Bigorajski?= Date: Fri, 20 Sep 2024 11:01:45 +0200 Subject: [PATCH] review fixes 4 --- .../engine/extension/SkipAutoDiscovery.java | 11 +++ .../clazz/ClassDefinitionExtractor.scala | 4 +- .../nussknacker/engine/extension/Cast.scala | 37 ++++---- ...ssDefinitionSetExtensionMethodsAware.scala | 92 ++++++++++++++++--- .../engine/spel/SpelExpressionSpec.scala | 17 +++- 5 files changed, 125 insertions(+), 36 deletions(-) create mode 100644 extensions-api/src/main/scala/pl/touk/nussknacker/engine/extension/SkipAutoDiscovery.java diff --git a/extensions-api/src/main/scala/pl/touk/nussknacker/engine/extension/SkipAutoDiscovery.java b/extensions-api/src/main/scala/pl/touk/nussknacker/engine/extension/SkipAutoDiscovery.java new file mode 100644 index 00000000000..83522356bab --- /dev/null +++ b/extensions-api/src/main/scala/pl/touk/nussknacker/engine/extension/SkipAutoDiscovery.java @@ -0,0 +1,11 @@ +package pl.touk.nussknacker.engine.extension; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface SkipAutoDiscovery { +} diff --git a/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/definition/clazz/ClassDefinitionExtractor.scala b/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/definition/clazz/ClassDefinitionExtractor.scala index 0bed06be8d0..d27866184d6 100644 --- a/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/definition/clazz/ClassDefinitionExtractor.scala +++ b/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/definition/clazz/ClassDefinitionExtractor.scala @@ -58,9 +58,7 @@ class ClassDefinitionExtractor(settings: ClassExtractionSettings) extends LazyLo ): Map[String, List[MethodDefinition]] = { val membersPredicate = settings.visibleMembersPredicate(clazz) val methods = extractPublicMethods(clazz, membersPredicate, staticMethodsAndFields) -// val extensionMethods = extractExtensionMethods(membersPredicate, staticMethodsAndFields) - val fields = extractPublicFields(clazz, membersPredicate, staticMethodsAndFields).mapValuesNow(List(_)) -// filterHiddenParameterAndReturnType(methods ++ fields ++ extensionMethods) + val fields = extractPublicFields(clazz, membersPredicate, staticMethodsAndFields).mapValuesNow(List(_)) methods ++ fields } diff --git a/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/Cast.scala b/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/Cast.scala index c97dae808c9..b8c184e7061 100644 --- a/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/Cast.scala +++ b/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/Cast.scala @@ -3,40 +3,43 @@ package pl.touk.nussknacker.engine.extension import cats.data.ValidatedNel import cats.implicits.catsSyntaxValidatedId import pl.touk.nussknacker.engine.api.Documentation -import pl.touk.nussknacker.engine.api.generics.{GenericFunctionTypingError, GenericType, TypingFunction} +import pl.touk.nussknacker.engine.api.generics.GenericFunctionTypingError import pl.touk.nussknacker.engine.api.typed.typing -import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectWithValue} +import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectWithValue, TypingResult} -import scala.util.{Failure, Success, Try} +import scala.util.Try +@SkipAutoDiscovery sealed trait Cast { @Documentation(description = "Checks if a type can be casted to a given class") def canCastTo(clazzType: String): Boolean @Documentation(description = "Casts a type to a given class or throws exception if type cannot be casted.") - @GenericType(typingFunction = classOf[CastTyping.Typing]) def castTo[T](clazzType: String): T } object CastTyping { - class Typing extends TypingFunction { - - override def computeResultType( - arguments: List[typing.TypingResult] - ): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] = arguments match { - case TypedObjectWithValue(_, clazzName: String) :: Nil => - Try(Class.forName(clazzName)) match { - case Success(clazz) => Typed.typedClass(clazz).validNel - case Failure(_) => GenericFunctionTypingError.ArgumentTypeError.invalidNel - } - case _ => GenericFunctionTypingError.ArgumentTypeError.invalidNel - } - + def castToTyping(allowedClassNamesWithTyping: Map[String, TypingResult])( + instanceType: typing.TypingResult, + arguments: List[typing.TypingResult] + ): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] = arguments match { + case TypedObjectWithValue(_, clazzName: String) :: Nil => + allowedClassNamesWithTyping.get(clazzName) match { + case Some(typing) => typing.validNel + case None => GenericFunctionTypingError.OtherError(s"$clazzName is not allowed").invalidNel + } + case _ => GenericFunctionTypingError.ArgumentTypeError.invalidNel } + def canCastToTyping(allowedClassNamesWithTyping: Map[String, TypingResult])( + instanceType: typing.TypingResult, + arguments: List[typing.TypingResult] + ): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] = + castToTyping(allowedClassNamesWithTyping)(instanceType, arguments).map(_ => Typed.typedClass[Boolean]) + } class CastImpl(target: Any) extends Cast { diff --git a/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/ClassDefinitionSetExtensionMethodsAware.scala b/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/ClassDefinitionSetExtensionMethodsAware.scala index dce2b4b1299..3dc79693994 100644 --- a/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/ClassDefinitionSetExtensionMethodsAware.scala +++ b/scenario-compiler/src/main/scala/pl/touk/nussknacker/engine/extension/ClassDefinitionSetExtensionMethodsAware.scala @@ -1,7 +1,11 @@ package pl.touk.nussknacker.engine.extension +import cats.data.ValidatedNel import pl.touk.nussknacker.engine.ModelData +import pl.touk.nussknacker.engine.api.Documentation +import pl.touk.nussknacker.engine.api.generics.{GenericFunctionTypingError, MethodTypeInfo, Parameter} import pl.touk.nussknacker.engine.api.process.ClassExtractionSettings +import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypingResult, Unknown} import pl.touk.nussknacker.engine.definition.clazz.ClassDefinitionExtractor.{ MethodDefinitionsExtension, MethodExtensions @@ -10,17 +14,24 @@ import pl.touk.nussknacker.engine.definition.clazz.{ ClassDefinition, ClassDefinitionExtractor, ClassDefinitionSet, + FunctionalMethodDefinition, MethodDefinition } +import pl.touk.nussknacker.engine.extension.ClassDefinitionSetExtensionMethodsAware.{ + createStaticallyDefinedMethodsMap, + extractExtensionMethods +} import pl.touk.nussknacker.engine.util.Implicits.{RichScalaMap, RichTupleList} import java.lang.reflect.{Method, Modifier} final case class ClassDefinitionSetExtensionMethodsAware(set: ClassDefinitionSet, settings: ClassExtractionSettings) { - private lazy val extractor = new ClassDefinitionExtractor(settings) - private lazy val extensionMethodsMap = extractExtensionMethods() + private lazy val extractor: ClassDefinitionExtractor = new ClassDefinitionExtractor(settings) + // We cannot have a class as the key because of `visibleMembersPredicate` e.g. Cast.castTo may be accessible for many classes. + private lazy val extensionMethodsMap: Map[Method, Map[String, List[MethodDefinition]]] = + extractExtensionMethods(settings, extractor) ++ createStaticallyDefinedMethodsMap(set) - val unknown: Option[ClassDefinition] = + lazy val unknown: Option[ClassDefinition] = getWithExtensionMethods(classOf[Any]) def get(clazz: Class[_]): Option[ClassDefinition] = @@ -43,20 +54,43 @@ final case class ClassDefinitionSetExtensionMethodsAware(set: ClassDefinitionSet .flatMap(_._2) } - private def extractExtensionMethods(): Map[Method, Map[String, List[MethodDefinition]]] = { +} + +object ClassDefinitionSetExtensionMethodsAware { + private val stringClass: Class[String] = classOf[String] + private val stringTyping: TypingResult = Typed.genericTypeClass(stringClass, Nil) + + def apply(modelData: ModelData): ClassDefinitionSetExtensionMethodsAware = + new ClassDefinitionSetExtensionMethodsAware( + modelData.modelDefinitionWithClasses.classDefinitions, + modelData.modelDefinition.settings + ) + + private def extractExtensionMethods( + settings: ClassExtractionSettings, + extractor: ClassDefinitionExtractor + ): Map[Method, Map[String, List[MethodDefinition]]] = { ExtensionMethods.registry - .flatMap(extractMethodsWithDefinitions) + .filter(filterAnnotatedClass) + .flatMap(clazz => extractMethodsWithDefinitions(extractor, clazz)) .groupBy(_._1) - .mapValuesNow(filterByVisibilityOfParams) + .mapValuesNow(definitionsSet => filterByVisibilityOfParams(settings, definitionsSet)) } - private def extractMethodsWithDefinitions(clazz: Class[_]): List[(Method, List[(String, MethodDefinition)])] = + private def filterAnnotatedClass(clazz: Class[_]): Boolean = + Option(clazz.getAnnotation(classOf[SkipAutoDiscovery])).isEmpty + + private def extractMethodsWithDefinitions( + extractor: ClassDefinitionExtractor, + clazz: Class[_] + ): List[(Method, List[(String, MethodDefinition)])] = clazz.getMethods.toList .filter(m => !Modifier.isStatic(m.getModifiers)) .filter(_.javaVersionOfVarArgMethod().isEmpty) .map(m => m -> extractor.extractMethod(clazz, m)) private def filterByVisibilityOfParams( + settings: ClassExtractionSettings, methodsWithDefinitions: Set[(Method, List[(String, MethodDefinition)])] ): Map[String, List[MethodDefinition]] = methodsWithDefinitions @@ -65,14 +99,44 @@ final case class ClassDefinitionSetExtensionMethodsAware(set: ClassDefinitionSet .toGroupedMap .filterHiddenParameterAndReturnType(settings) -} - -object ClassDefinitionSetExtensionMethodsAware { + private def createStaticallyDefinedMethodsMap( + set: ClassDefinitionSet + ): Map[Method, Map[String, List[MethodDefinition]]] = { + val allowedClasses = set.classDefinitionsMap.map(e => e._1.getName -> e._2.clazzName) + val castClass = classOf[Cast] + val canCastToMethod = castClass.getDeclaredMethod("canCastTo", stringClass) + val castToMethod = castClass.getDeclaredMethod("castTo", stringClass) + Map( + canCastToMethod -> Map( + canCastToMethod.getName -> List( + castFunctionalMethodDefinition(canCastToMethod, CastTyping.canCastToTyping(allowedClasses)) + ) + ), + castToMethod -> Map( + castToMethod.getName -> List( + castFunctionalMethodDefinition(castToMethod, CastTyping.castToTyping(allowedClasses)) + ) + ) + ) + } - def apply(modelData: ModelData): ClassDefinitionSetExtensionMethodsAware = - new ClassDefinitionSetExtensionMethodsAware( - modelData.modelDefinitionWithClasses.classDefinitions, - modelData.modelDefinition.settings + private def castFunctionalMethodDefinition( + method: Method, + typeFunction: (TypingResult, List[TypingResult]) => ValidatedNel[GenericFunctionTypingError, TypingResult] + ): FunctionalMethodDefinition = + FunctionalMethodDefinition( + typeFunction = typeFunction, + signature = MethodTypeInfo( + noVarArgs = List( + Parameter("clazzType", stringTyping) + ), + varArg = None, + result = Unknown + ), + name = method.getName, + description = getDocumentationAnnotationValue(method) ) + private def getDocumentationAnnotationValue(method: Method): Option[String] = + Option(method.getAnnotation(classOf[Documentation])).map(_.description()) } diff --git a/scenario-compiler/src/test/scala/pl/touk/nussknacker/engine/spel/SpelExpressionSpec.scala b/scenario-compiler/src/test/scala/pl/touk/nussknacker/engine/spel/SpelExpressionSpec.scala index 0f726833b2a..1aea6718671 100644 --- a/scenario-compiler/src/test/scala/pl/touk/nussknacker/engine/spel/SpelExpressionSpec.scala +++ b/scenario-compiler/src/test/scala/pl/touk/nussknacker/engine/spel/SpelExpressionSpec.scala @@ -43,7 +43,11 @@ import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.MissingObjectErr } import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.OperatorError._ import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.UnsupportedOperationError.ArrayConstructorError -import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.{ArgumentTypeError, ExpressionTypeError} +import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.{ + ArgumentTypeError, + ExpressionTypeError, + GenericFunctionError +} import pl.touk.nussknacker.engine.spel.SpelExpressionParser.{Flavour, Standard} import pl.touk.nussknacker.engine.testing.ModelDefinitionBuilder import pl.touk.nussknacker.test.ValidatedValuesDetailedMessage @@ -1394,7 +1398,7 @@ class SpelExpressionSpec extends AnyFunSuite with Matchers with ValidatedValuesD test("should return an error if the cast return type cannot be determined at parse time") { parse[Any]("{11, 12}.castTo('java.util.XYZ')", ctx).invalidValue.toList should matchPattern { - case ArgumentTypeError("castTo", _, _) :: Nil => + case GenericFunctionError("java.util.XYZ is not allowed") :: Nil => } parse[Any]("{11, 12}.castTo(#obj.id)", ctx).invalidValue.toList should matchPattern { case ArgumentTypeError("castTo", _, _) :: Nil => @@ -1478,6 +1482,15 @@ class SpelExpressionSpec extends AnyFunSuite with Matchers with ValidatedValuesD } } + test("should not allow cast to disallowed classes") { + parse[Any]( + "#hashMap.castTo('java.util.HashMap').remove('testKey')", + ctx.withVariable("hashMap", new java.util.HashMap[String, Int](Map("testKey" -> 2).asJava)) + ).invalidValue.toList should matchPattern { + case GenericFunctionError("java.util.HashMap is not allowed") :: IllegalInvocationError(Unknown) :: Nil => + } + } + } case class SampleObject(list: java.util.List[SampleValue])