Skip to content

Commit

Permalink
Improvement: FlinkBaseTypeInfoRegister mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
lciolecki committed Oct 28, 2024
1 parent a935294 commit 88224db
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package pl.touk.nussknacker.engine.flink.api.typeinformation

import org.apache.flink.api.common.functions.InvalidTypesException
import org.apache.flink.api.common.typeinfo.{TypeInfoFactory, TypeInformation, Types}
import org.apache.flink.api.java.typeutils.TypeExtractor

import java.lang.reflect.Type
import java.util
import scala.reflect.{ClassTag, classTag}

object FlinkBaseTypeInfoRegister {

private val BaseTypes = List(
Types.LOCAL_DATE,
Types.LOCAL_TIME,
Types.LOCAL_DATE_TIME,
Types.INSTANT,
Types.SQL_DATE,
Types.SQL_TIME,
Types.SQL_TIMESTAMP,
)

def makeSureBaseTypesAreRegistered(): Unit =
BaseTypes.foreach(register(_))

private def register[T: ClassTag](typeInformation: TypeInformation[T]): Unit = {
val klass = classTag[T].runtimeClass.asInstanceOf[Class[T]]
val factory = new TypeInfoFactory[T] {
override def createTypeInfo(
t: Type,
genericParameters: util.Map[String, TypeInformation[_]]
): TypeInformation[T] =
typeInformation
}

try {
TypeExtractor.registerFactory(klass, factory.getClass)
} catch {
case exc: InvalidTypesException
if exc.getMessage == s"A TypeInfoFactory for type '$klass' is already registered." =>
// Factory has been already registered
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ object TypeInformationDetection {
// We use SPI to provide implementation of TypeInformationDetection because we don't want to make
// implementation classes available in flink-components-api module.
val instance: TypeInformationDetection = {

FlinkBaseTypeInfoRegister.makeSureBaseTypesAreRegistered()

val classloader = Thread.currentThread().getContextClassLoader
ServiceLoader
.load(classOf[TypeInformationDetection], classloader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import pl.touk.nussknacker.engine.ModelData
import pl.touk.nussknacker.engine.api.namespaces.NamingStrategy
import pl.touk.nussknacker.engine.api.{JobData, ProcessVersion}
import pl.touk.nussknacker.engine.deployment.DeploymentData
import pl.touk.nussknacker.engine.flink.api.typeinformation.FlinkBaseTypeInfoRegister
import pl.touk.nussknacker.engine.flink.api.{NamespaceMetricsTags, NkGlobalParameters}
import pl.touk.nussknacker.engine.process.util.Serializers

Expand Down Expand Up @@ -100,6 +101,7 @@ object ExecutionConfigPreparer extends LazyLogging {
override def prepareExecutionConfig(
config: ExecutionConfig
)(jobData: JobData, deploymentData: DeploymentData): Unit = {
FlinkBaseTypeInfoRegister.makeSureBaseTypesAreRegistered()
Serializers.registerSerializers(modelData, config)
if (enableObjectReuse) {
config.enableObjectReuse()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,6 @@ import pl.touk.nussknacker.engine.util.Implicits._
*/
class TypingResultAwareTypeInformationDetection extends TypeInformationDetection {

private val registeredTypeInfos: Map[TypedClass, TypeInformation[_]] = Map(
Typed.typedClass[String] -> Types.STRING,
Typed.typedClass[Boolean] -> Types.BOOLEAN,
Typed.typedClass[Byte] -> Types.BYTE,
Typed.typedClass[Short] -> Types.SHORT,
Typed.typedClass[Integer] -> Types.INT,
Typed.typedClass[Long] -> Types.LONG,
Typed.typedClass[Float] -> Types.FLOAT,
Typed.typedClass[Double] -> Types.DOUBLE,
Typed.typedClass[Character] -> Types.CHAR,
Typed.typedClass[java.math.BigDecimal] -> Types.BIG_DEC,
Typed.typedClass[java.math.BigInteger] -> Types.BIG_INT,
Typed.typedClass[java.time.LocalDate] -> Types.LOCAL_DATE,
Typed.typedClass[java.time.LocalTime] -> Types.LOCAL_TIME,
Typed.typedClass[java.time.LocalDateTime] -> Types.LOCAL_DATE_TIME,
Typed.typedClass[java.time.Instant] -> Types.INSTANT,
Typed.typedClass[java.sql.Date] -> Types.SQL_DATE,
Typed.typedClass[java.sql.Time] -> Types.SQL_TIME,
Typed.typedClass[java.sql.Timestamp] -> Types.SQL_TIMESTAMP,
)

def forContext(validationContext: ValidationContext): TypeInformation[Context] = {
val variables = forType(
Typed.record(validationContext.localVariables, Typed.typedClass[Map[String, AnyRef]])
Expand Down Expand Up @@ -82,8 +61,6 @@ class TypingResultAwareTypeInformationDetection extends TypeInformationDetection
// We generally don't use scala Maps in our runtime, but it is useful for some internal type infos: TODO move it somewhere else
case a: TypedObjectTypingResult if a.runtimeObjType.klass == classOf[Map[String, _]] =>
createScalaMapTypeInformation(a)
case a: SingleTypingResult if registeredTypeInfos.contains(a.runtimeObjType) =>
registeredTypeInfos(a.runtimeObjType)
// TODO: scala case classes are not handled nicely here... CaseClassTypeInfo is created only via macro, here Kryo is used
case a: SingleTypingResult if a.runtimeObjType.params.isEmpty =>
TypeInformation.of(a.runtimeObjType.klass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import pl.touk.nussknacker.engine.api.context.ValidationContext
import pl.touk.nussknacker.engine.api.typed.typing.Typed
import pl.touk.nussknacker.engine.api.{Context, ValueWithContext}
import pl.touk.nussknacker.engine.flink.api.typeinfo.caseclass.ScalaCaseClassSerializer
import pl.touk.nussknacker.engine.flink.api.typeinformation.FlinkBaseTypeInfoRegister
import pl.touk.nussknacker.engine.flink.serialization.FlinkTypeInformationSerializationMixin
import pl.touk.nussknacker.engine.process.typeinformation.internal.typedobject._
import pl.touk.nussknacker.engine.process.typeinformation.testTypedObject.CustomTypedObject
Expand Down

0 comments on commit 88224db

Please sign in to comment.