Skip to content

Commit

Permalink
review fixes 4
Browse files Browse the repository at this point in the history
  • Loading branch information
Łukasz Bigorajski committed Sep 20, 2024
1 parent 1cca975 commit 4190e76
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package pl.touk.nussknacker.engine.extension;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SkipAutoDiscovery {
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ class ClassDefinitionExtractor(settings: ClassExtractionSettings) extends LazyLo
): Map[String, List[MethodDefinition]] = {
val membersPredicate = settings.visibleMembersPredicate(clazz)
val methods = extractPublicMethods(clazz, membersPredicate, staticMethodsAndFields)
// val extensionMethods = extractExtensionMethods(membersPredicate, staticMethodsAndFields)
val fields = extractPublicFields(clazz, membersPredicate, staticMethodsAndFields).mapValuesNow(List(_))
// filterHiddenParameterAndReturnType(methods ++ fields ++ extensionMethods)
val fields = extractPublicFields(clazz, membersPredicate, staticMethodsAndFields).mapValuesNow(List(_))
methods ++ fields
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,43 @@ package pl.touk.nussknacker.engine.extension
import cats.data.ValidatedNel
import cats.implicits.catsSyntaxValidatedId
import pl.touk.nussknacker.engine.api.Documentation
import pl.touk.nussknacker.engine.api.generics.{GenericFunctionTypingError, GenericType, TypingFunction}
import pl.touk.nussknacker.engine.api.generics.GenericFunctionTypingError
import pl.touk.nussknacker.engine.api.typed.typing
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectWithValue}
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectWithValue, TypingResult}

import scala.util.{Failure, Success, Try}
import scala.util.Try

@SkipAutoDiscovery
sealed trait Cast {

@Documentation(description = "Checks if a type can be casted to a given class")
def canCastTo(clazzType: String): Boolean

@Documentation(description = "Casts a type to a given class or throws exception if type cannot be casted.")
@GenericType(typingFunction = classOf[CastTyping.Typing])
def castTo[T](clazzType: String): T

}

object CastTyping {

class Typing extends TypingFunction {

override def computeResultType(
arguments: List[typing.TypingResult]
): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] = arguments match {
case TypedObjectWithValue(_, clazzName: String) :: Nil =>
Try(Class.forName(clazzName)) match {
case Success(clazz) => Typed.typedClass(clazz).validNel
case Failure(_) => GenericFunctionTypingError.ArgumentTypeError.invalidNel
}
case _ => GenericFunctionTypingError.ArgumentTypeError.invalidNel
}

def castToTyping(allowedClassNamesWithTyping: Map[String, TypingResult])(
instanceType: typing.TypingResult,
arguments: List[typing.TypingResult]
): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] = arguments match {
case TypedObjectWithValue(_, clazzName: String) :: Nil =>
allowedClassNamesWithTyping.get(clazzName) match {
case Some(typing) => typing.validNel
case None => GenericFunctionTypingError.OtherError(s"$clazzName is not allowed").invalidNel
}
case _ => GenericFunctionTypingError.ArgumentTypeError.invalidNel
}

def canCastToTyping(allowedClassNamesWithTyping: Map[String, TypingResult])(
instanceType: typing.TypingResult,
arguments: List[typing.TypingResult]
): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] =
castToTyping(allowedClassNamesWithTyping)(instanceType, arguments).map(_ => Typed.typedClass[Boolean])

}

class CastImpl(target: Any) extends Cast {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package pl.touk.nussknacker.engine.extension

import cats.data.ValidatedNel
import pl.touk.nussknacker.engine.ModelData
import pl.touk.nussknacker.engine.api.Documentation
import pl.touk.nussknacker.engine.api.generics.{GenericFunctionTypingError, MethodTypeInfo, Parameter}
import pl.touk.nussknacker.engine.api.process.ClassExtractionSettings
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypingResult, Unknown}
import pl.touk.nussknacker.engine.definition.clazz.ClassDefinitionExtractor.{
MethodDefinitionsExtension,
MethodExtensions
Expand All @@ -10,17 +14,24 @@ import pl.touk.nussknacker.engine.definition.clazz.{
ClassDefinition,
ClassDefinitionExtractor,
ClassDefinitionSet,
FunctionalMethodDefinition,
MethodDefinition
}
import pl.touk.nussknacker.engine.extension.ClassDefinitionSetExtensionMethodsAware.{
createStaticallyDefinedMethodsMap,
extractExtensionMethods
}
import pl.touk.nussknacker.engine.util.Implicits.{RichScalaMap, RichTupleList}

import java.lang.reflect.{Method, Modifier}

final case class ClassDefinitionSetExtensionMethodsAware(set: ClassDefinitionSet, settings: ClassExtractionSettings) {
private lazy val extractor = new ClassDefinitionExtractor(settings)
private lazy val extensionMethodsMap = extractExtensionMethods()
private lazy val extractor: ClassDefinitionExtractor = new ClassDefinitionExtractor(settings)
// We cannot have a class as the key because of `visibleMembersPredicate` e.g. Cast.castTo may be accessible for many classes.
private lazy val extensionMethodsMap: Map[Method, Map[String, List[MethodDefinition]]] =
extractExtensionMethods(settings, extractor) ++ createStaticallyDefinedMethodsMap(set)

val unknown: Option[ClassDefinition] =
lazy val unknown: Option[ClassDefinition] =
getWithExtensionMethods(classOf[Any])

def get(clazz: Class[_]): Option[ClassDefinition] =
Expand All @@ -43,20 +54,43 @@ final case class ClassDefinitionSetExtensionMethodsAware(set: ClassDefinitionSet
.flatMap(_._2)
}

private def extractExtensionMethods(): Map[Method, Map[String, List[MethodDefinition]]] = {
}

object ClassDefinitionSetExtensionMethodsAware {
private val stringClass: Class[String] = classOf[String]
private val stringTyping: TypingResult = Typed.genericTypeClass(stringClass, Nil)

def apply(modelData: ModelData): ClassDefinitionSetExtensionMethodsAware =
new ClassDefinitionSetExtensionMethodsAware(
modelData.modelDefinitionWithClasses.classDefinitions,
modelData.modelDefinition.settings
)

private def extractExtensionMethods(
settings: ClassExtractionSettings,
extractor: ClassDefinitionExtractor
): Map[Method, Map[String, List[MethodDefinition]]] = {
ExtensionMethods.registry
.flatMap(extractMethodsWithDefinitions)
.filter(filterAnnotatedClass)
.flatMap(clazz => extractMethodsWithDefinitions(extractor, clazz))
.groupBy(_._1)
.mapValuesNow(filterByVisibilityOfParams)
.mapValuesNow(definitionsSet => filterByVisibilityOfParams(settings, definitionsSet))
}

private def extractMethodsWithDefinitions(clazz: Class[_]): List[(Method, List[(String, MethodDefinition)])] =
private def filterAnnotatedClass(clazz: Class[_]): Boolean =
Option(clazz.getAnnotation(classOf[SkipAutoDiscovery])).isEmpty

private def extractMethodsWithDefinitions(
extractor: ClassDefinitionExtractor,
clazz: Class[_]
): List[(Method, List[(String, MethodDefinition)])] =
clazz.getMethods.toList
.filter(m => !Modifier.isStatic(m.getModifiers))
.filter(_.javaVersionOfVarArgMethod().isEmpty)
.map(m => m -> extractor.extractMethod(clazz, m))

private def filterByVisibilityOfParams(
settings: ClassExtractionSettings,
methodsWithDefinitions: Set[(Method, List[(String, MethodDefinition)])]
): Map[String, List[MethodDefinition]] =
methodsWithDefinitions
Expand All @@ -65,14 +99,44 @@ final case class ClassDefinitionSetExtensionMethodsAware(set: ClassDefinitionSet
.toGroupedMap
.filterHiddenParameterAndReturnType(settings)

}

object ClassDefinitionSetExtensionMethodsAware {
private def createStaticallyDefinedMethodsMap(
set: ClassDefinitionSet
): Map[Method, Map[String, List[MethodDefinition]]] = {
val allowedClasses = set.classDefinitionsMap.map(e => e._1.getName -> e._2.clazzName)
val castClass = classOf[Cast]
val canCastToMethod = castClass.getDeclaredMethod("canCastTo", stringClass)
val castToMethod = castClass.getDeclaredMethod("castTo", stringClass)
Map(
canCastToMethod -> Map(
canCastToMethod.getName -> List(
castFunctionalMethodDefinition(canCastToMethod, CastTyping.canCastToTyping(allowedClasses))
)
),
castToMethod -> Map(
castToMethod.getName -> List(
castFunctionalMethodDefinition(castToMethod, CastTyping.castToTyping(allowedClasses))
)
)
)
}

def apply(modelData: ModelData): ClassDefinitionSetExtensionMethodsAware =
new ClassDefinitionSetExtensionMethodsAware(
modelData.modelDefinitionWithClasses.classDefinitions,
modelData.modelDefinition.settings
private def castFunctionalMethodDefinition(
method: Method,
typeFunction: (TypingResult, List[TypingResult]) => ValidatedNel[GenericFunctionTypingError, TypingResult]
): FunctionalMethodDefinition =
FunctionalMethodDefinition(
typeFunction = typeFunction,
signature = MethodTypeInfo(
noVarArgs = List(
Parameter("clazzType", stringTyping)
),
varArg = None,
result = Unknown
),
name = method.getName,
description = getDocumentationAnnotationValue(method)
)

private def getDocumentationAnnotationValue(method: Method): Option[String] =
Option(method.getAnnotation(classOf[Documentation])).map(_.description())
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.MissingObjectErr
}
import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.OperatorError._
import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.UnsupportedOperationError.ArrayConstructorError
import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.{ArgumentTypeError, ExpressionTypeError}
import pl.touk.nussknacker.engine.spel.SpelExpressionParseError.{
ArgumentTypeError,
ExpressionTypeError,
GenericFunctionError
}
import pl.touk.nussknacker.engine.spel.SpelExpressionParser.{Flavour, Standard}
import pl.touk.nussknacker.engine.testing.ModelDefinitionBuilder
import pl.touk.nussknacker.test.ValidatedValuesDetailedMessage
Expand Down Expand Up @@ -1394,7 +1398,7 @@ class SpelExpressionSpec extends AnyFunSuite with Matchers with ValidatedValuesD

test("should return an error if the cast return type cannot be determined at parse time") {
parse[Any]("{11, 12}.castTo('java.util.XYZ')", ctx).invalidValue.toList should matchPattern {
case ArgumentTypeError("castTo", _, _) :: Nil =>
case GenericFunctionError("java.util.XYZ is not allowed") :: Nil =>
}
parse[Any]("{11, 12}.castTo(#obj.id)", ctx).invalidValue.toList should matchPattern {
case ArgumentTypeError("castTo", _, _) :: Nil =>
Expand Down Expand Up @@ -1478,6 +1482,15 @@ class SpelExpressionSpec extends AnyFunSuite with Matchers with ValidatedValuesD
}
}

test("should not allow cast to disallowed classes") {
parse[Any](
"#hashMap.castTo('java.util.HashMap').remove('testKey')",
ctx.withVariable("hashMap", new java.util.HashMap[String, Int](Map("testKey" -> 2).asJava))
).invalidValue.toList should matchPattern {
case GenericFunctionError("java.util.HashMap is not allowed") :: IllegalInvocationError(Unknown) :: Nil =>
}
}

}

case class SampleObject(list: java.util.List[SampleValue])
Expand Down

0 comments on commit 4190e76

Please sign in to comment.