Skip to content

Commit

Permalink
review fixes 7
Browse files Browse the repository at this point in the history
  • Loading branch information
Łukasz Bigorajski committed Sep 23, 2024
1 parent 762a141 commit d1a67f3
Show file tree
Hide file tree
Showing 32 changed files with 172 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class SpelBenchmarkSetup(expression: String, vars: Map[String, AnyRef]) {
new SimpleDictRegistry(Map.empty),
expressionConfig,
classDefinitionSet = ClassDefinitionTestUtils.createDefinitionForDefaultAdditionalClasses,
evaluator,
ClassExtractionSettings.Default
evaluator
)

private val validationContext = ValidationContext(vars.mapValuesNow(Typed.fromInstance), Map.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ class SpelSecurityBenchmarkSetup(expression: String, vars: Map[String, AnyRef])
new SimpleDictRegistry(Map.empty),
expressionDefinition,
classDefinitionSet = ClassDefinitionTestUtils.createDefinitionForDefaultAdditionalClasses,
evaluator,
ClassExtractionSettings.Default
evaluator
)

private val validationContext = ValidationContext(vars.mapValuesNow(Typed.fromInstance), Map.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ class ExpressionSuggesterBenchmarkSetup() {
clazzDefinitions,
dictServices,
getClass.getClassLoader,
List.empty,
ClassDefinitionTestUtils.DefaultSettings
List.empty
)

private val variables: Map[String, TypingResult] = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,21 @@ import pl.touk.nussknacker.engine.graph.expression.Expression.Language
import pl.touk.nussknacker.engine.spel.{ExpressionSuggestion, SpelExpressionSuggester}
import pl.touk.nussknacker.engine.variables.GlobalVariablesPreparer
import pl.touk.nussknacker.engine.ModelData
import pl.touk.nussknacker.engine.api.process.ClassExtractionSettings
import pl.touk.nussknacker.engine.definition.clazz.ClassDefinitionSet
import pl.touk.nussknacker.engine.util.CaretPosition2d

import scala.concurrent.{ExecutionContext, Future}

// todo: lbg imports
class ExpressionSuggester(
expressionDefinition: ExpressionConfigDefinition,
classDefinitions: ClassDefinitionSet,
uiDictServices: UiDictServices,
classLoader: ClassLoader,
scenarioPropertiesNames: Iterable[String],
settings: ClassExtractionSettings,
scenarioPropertiesNames: Iterable[String]
) {

private val spelExpressionSuggester =
new SpelExpressionSuggester(expressionDefinition, classDefinitions, uiDictServices, classLoader, settings)
new SpelExpressionSuggester(expressionDefinition, classDefinitions, uiDictServices, classLoader)

private val validationContextGlobalVariablesOnly =
GlobalVariablesPreparer(expressionDefinition).prepareValidationContextWithGlobalVariablesOnly(
Expand Down Expand Up @@ -63,8 +60,7 @@ object ExpressionSuggester {
modelData.modelDefinitionWithClasses.classDefinitions,
modelData.designerDictServices,
modelData.modelClassLoader.classLoader,
scenarioPropertiesNames,
modelData.modelDefinition.settings
scenarioPropertiesNames
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ class ExpressionSuggesterSpec
clazzDefinitions,
dictServices,
getClass.getClassLoader,
List("scenarioProperty"),
ClassDefinitionTestUtils.DefaultSettings
List("scenarioProperty")
)

private val localVariables: Map[String, TypingResult] = Map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ object TestFlinkProcessCompilerDataFactory {
context.expressionConfig,
context.dictRegistry,
context.classDefinitions,
process.metaData,
modelData.modelDefinition.settings,
process.metaData
),
scenarioTestData
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ class AvroSchemaSpelExpressionSpec extends AnyFunSpec with Matchers {
new SimpleDictRegistry(Map(dictId -> EmbeddedDictDefinition(Map("key1" -> "value1")))),
enableSpelForceCompile = true,
Standard,
ClassDefinitionTestUtils.createDefinitionForClasses(classOf[EnumSymbol]),
ClassDefinitionTestUtils.DefaultSettings
ClassDefinitionTestUtils.createDefinitionForClasses(classOf[EnumSymbol])
)
.parse(expr, validationCtx, Typed.fromDetailedType[T])
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ public class NuReflectiveMethodExecutor extends ReflectiveMethodExecutor {

private boolean argumentConversionOccurred = false;

public NuReflectiveMethodExecutor(ReflectiveMethodExecutor original) {
private final ClassLoader classLoader;

public NuReflectiveMethodExecutor(ReflectiveMethodExecutor original, ClassLoader classLoader) {
super(original.getMethod());
this.method = original.getMethod();
if (method.isVarArgs()) {
Expand All @@ -42,6 +44,7 @@ public NuReflectiveMethodExecutor(ReflectiveMethodExecutor original) {
else {
this.varargsPosition = null;
}
this.classLoader = classLoader;
}

/**
Expand Down Expand Up @@ -98,7 +101,7 @@ public TypedValue execute(EvaluationContext context, Object target, Object... ar
}
ReflectionUtils.makeAccessible(this.method);
//Nussknacker: we use custom method invoker which is aware of array conversion
Object value = methodInvoker.invoke(this.method, target, arguments);
Object value = methodInvoker.invoke(this.method, target, arguments, this.classLoader);
return new TypedValue(value, new TypeDescriptor(new MethodParameter(this.method, -1)).narrow(value));
}
catch (Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ private static Method[] concatArrays(Method[] a, Method[] b) {

public static final class ConversionAwareMethodInvoker {

public Object invoke(Method method, Object target, Object[] arguments) throws IllegalAccessException, InvocationTargetException {
public Object invoke(Method method,
Object target,
Object[] arguments,
ClassLoader classLoader) throws IllegalAccessException, InvocationTargetException {
if (target != null && target.getClass().isArray() && method.getDeclaringClass().isAssignableFrom(List.class)) {
return method.invoke(RuntimeConversionHandler.convert(target), arguments);
} else if (ExtensionMethods.applies(method.getDeclaringClass())) {
return ExtensionMethods.invoke(method, target, arguments);
return ExtensionMethods.invoke(method, target, arguments, classLoader);
} else {
return method.invoke(target, arguments);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,35 +41,31 @@ object ExpressionCompiler {
dictRegistry: DictRegistry,
expressionConfig: ExpressionConfigDefinition,
classDefinitionSet: ClassDefinitionSet,
expressionEvaluator: ExpressionEvaluator,
settings: ClassExtractionSettings,
expressionEvaluator: ExpressionEvaluator
): ExpressionCompiler =
default(
loader,
dictRegistry,
expressionConfig,
expressionConfig.optimizeCompilation,
classDefinitionSet,
expressionEvaluator,
settings,
expressionEvaluator
)

def withoutOptimization(
loader: ClassLoader,
dictRegistry: DictRegistry,
expressionConfig: ExpressionConfigDefinition,
classDefinitionSet: ClassDefinitionSet,
expressionEvaluator: ExpressionEvaluator,
settings: ClassExtractionSettings,
expressionEvaluator: ExpressionEvaluator
): ExpressionCompiler =
default(
loader,
dictRegistry,
expressionConfig,
optimizeCompilation = false,
classDefinitionSet,
expressionEvaluator,
settings,
expressionEvaluator
)

def withoutOptimization(modelData: ModelData): ExpressionCompiler = {
Expand All @@ -80,8 +76,7 @@ object ExpressionCompiler {
modelData.modelDefinitionWithClasses.classDefinitions,
ExpressionEvaluator.unOptimizedEvaluator(
GlobalVariablesPreparer(modelData.modelDefinition.expressionConfig)
),
modelData.modelDefinition.settings,
)
)
}

Expand All @@ -91,8 +86,7 @@ object ExpressionCompiler {
expressionConfig: ExpressionConfigDefinition,
optimizeCompilation: Boolean,
classDefinitionSet: ClassDefinitionSet,
expressionEvaluator: ExpressionEvaluator,
settings: ClassExtractionSettings,
expressionEvaluator: ExpressionEvaluator
): ExpressionCompiler = {
def spelParser(flavour: Flavour) =
SpelExpressionParser.default(
Expand All @@ -101,8 +95,7 @@ object ExpressionCompiler {
dictRegistry,
optimizeCompilation,
flavour,
classDefinitionSet,
settings,
classDefinitionSet
)

val defaultParsers =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,7 @@ object ProcessValidator {
dictRegistry,
modelDefinition.expressionConfig,
definitionWithTypes.classDefinitions,
expressionEvaluator,
definitionWithTypes.modelDefinition.settings
expressionEvaluator
)

val nodeCompiler = new NodeCompiler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ object ProcessCompilerData {
dictRegistry,
definitionWithTypes.modelDefinition.expressionConfig,
definitionWithTypes.classDefinitions,
expressionEvaluator,
definitionWithTypes.modelDefinition.settings
expressionEvaluator
)

// for testing environment it's important to take classloader from user jar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ object ClassDefinitionSet {
}

case class ClassDefinitionSet(classDefinitionsMap: Map[Class[_], ClassDefinition]) {
lazy val unknown = get(classOf[java.lang.Object])

def all: Set[ClassDefinition] = classDefinitionsMap.values.toSet

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package pl.touk.nussknacker.engine.definition.model

import pl.touk.nussknacker.engine.definition.clazz.ClassDefinitionSet
import pl.touk.nussknacker.engine.definition.component.ComponentDefinitionWithImplementation
import pl.touk.nussknacker.engine.extension.ClassDefinitionSetWithExtensionMethods

case class ModelDefinitionWithClasses(modelDefinition: ModelDefinition) {

@transient lazy val classDefinitions: ClassDefinitionSet = ClassDefinitionSet(
ModelClassDefinitionDiscovery.discoverClasses(modelDefinition)
)
@transient lazy val classDefinitions: ClassDefinitionSet = new ClassDefinitionSetWithExtensionMethods(
ClassDefinitionSet(ModelClassDefinitionDiscovery.discoverClasses(modelDefinition)),
modelDefinition.settings
).value

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package pl.touk.nussknacker.engine.extension

import pl.touk.nussknacker.engine.api.typed.typing.TypingResult
import pl.touk.nussknacker.engine.definition.clazz.ClassDefinitionSet

import scala.util.Try

final case class AllowedClasses(namesWithTyping: Map[String, TypingResult]) {
def get(className: String): Option[TypingResult] =
namesWithTyping.get(className)
}

object AllowedClasses {

def apply(set: ClassDefinitionSet): AllowedClasses =
new AllowedClasses(
namesWithTyping = set.classDefinitionsMap
.map { case (clazz, classDefinition) =>
clazz.getName -> Try(classDefinition.clazzName).toOption
}
.collect { case (className: String, Some(t)) =>
className -> t
}
.filterNot(e => isScalaObject(e._1))
)

private def isScalaObject(className: String): Boolean =
className.contains("$")
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,55 @@ import cats.implicits.catsSyntaxValidatedId
import pl.touk.nussknacker.engine.api.Documentation
import pl.touk.nussknacker.engine.api.generics.GenericFunctionTypingError
import pl.touk.nussknacker.engine.api.typed.typing
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectWithValue, TypingResult}
import pl.touk.nussknacker.engine.api.typed.typing.{Typed, TypedObjectWithValue}

import scala.util.Try

@SkipAutoDiscovery
sealed trait Cast {

@Documentation(description = "Checks if a type can be casted to a given class")
def canCastTo(clazzType: String): Boolean
def canCastTo(className: String): Boolean

@Documentation(description = "Casts a type to a given class or throws exception if type cannot be casted.")
def castTo[T](clazzType: String): T
def castTo[T](className: String): T

}

object CastTyping {

def castToTyping(allowedClassNamesWithTyping: Map[String, TypingResult])(
def castToTyping(allowedClasses: AllowedClasses)(
instanceType: typing.TypingResult,
arguments: List[typing.TypingResult]
): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] = arguments match {
case TypedObjectWithValue(_, clazzName: String) :: Nil =>
allowedClassNamesWithTyping.get(clazzName) match {
allowedClasses.get(clazzName) match {
case Some(typing) => typing.validNel
case None => GenericFunctionTypingError.OtherError(s"$clazzName is not allowed").invalidNel
}
case _ => GenericFunctionTypingError.ArgumentTypeError.invalidNel
}

def canCastToTyping(allowedClassNamesWithTyping: Map[String, TypingResult])(
def canCastToTyping(allowedClasses: AllowedClasses)(
instanceType: typing.TypingResult,
arguments: List[typing.TypingResult]
): ValidatedNel[GenericFunctionTypingError, typing.TypingResult] =
castToTyping(allowedClassNamesWithTyping)(instanceType, arguments).map(_ => Typed.typedClass[Boolean])
castToTyping(allowedClasses)(instanceType, arguments).map(_ => Typed.typedClass[Boolean])

}

class CastImpl(target: Any) extends Cast {
class CastImpl(target: Any, classLoader: ClassLoader) extends Cast {

override def canCastTo(clazzType: String): Boolean =
Class.forName(clazzType).isAssignableFrom(target.getClass)
override def canCastTo(className: String): Boolean =
classLoader.loadClass(className).isAssignableFrom(target.getClass)

override def castTo[T](clazzType: String): T = Try {
val clazz = Class.forName(clazzType)
override def castTo[T](className: String): T = Try {
val clazz = classLoader.loadClass(className)
if (clazz.isInstance(target)) {
clazz.cast(target).asInstanceOf[T]
} else {
throw new ClassCastException(s"Cannot cast: ${target.getClass} to: $clazzType")
throw new ClassCastException(s"Cannot cast: ${target.getClass} to: $className")
}
}.get

}

object CastImpl {

def apply(target: Any): Cast =
new CastImpl(target)
}
Loading

0 comments on commit d1a67f3

Please sign in to comment.