Skip to content

Commit

Permalink
Nested collections support (#10)
Browse files Browse the repository at this point in the history
* Support nested collections (List<List<...>>, Array<Set<...>>...)

* Add documentation
  • Loading branch information
natario1 authored Aug 20, 2024
1 parent a7843f1 commit 1687f1d
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 59 deletions.
5 changes: 4 additions & 1 deletion docs/features/builtin-types.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ For example, since `Int` is serializable, Knee can also serialize:
- `List<Int>`
- `Set<Int>`

As you may, the performance of these options is not the same because the `IntArray` signature avoids boxing.
As you may know, the performance of these options is not the same because the `IntArray` signature avoids boxing.

> In case of non-primitive values, `Array<Type>` will be used. That may still perform better than `List` or `Set`,
> although not dramatically better.
Note that since you can serialize collections of any serializable type, and collections themselves are serializable,
nested types are supported. For example, you may pass things like `List<Set<Float>>`, `Set<IntArray>` or `Array<Array<MyObject>>`.
1 change: 0 additions & 1 deletion knee-compiler-plugin/src/main/kotlin/DownwardFunctions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ private fun KneeDownwardFunction.makeIr(context: KneeContext, signature: Downwar
+when (val type = signature.result.encodedType) {
is JniType.Void -> irUnit()
is JniType.Object -> irNull()
is JniType.Array -> irNull()
is JniType.Int -> irInt(0)
is JniType.Long -> irLong(0)
is JniType.Float -> IrConstImpl.float(startOffset, endOffset, type.kn, 0F)
Expand Down
17 changes: 8 additions & 9 deletions knee-compiler-plugin/src/main/kotlin/codec/CollectionCodec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fun Codec.collectionCodecs(
* - jni type is jobject <-> kotlin.String [StringCodec.encodedType]
* - local codegen type is kotlin.String [StringCodec.localCodegenType]
*
* First of all, the jni representation of a collection of strings is [JniType.Array].
* First of all, the jni representation of a collection of strings is [JniType.Object].
* This is determined automatically by [JniType.Object.array].
*
* Then the mapper must decode a jobjectArray into a List/Set/Array/Sequence of strings. This
Expand All @@ -59,27 +59,26 @@ fun Codec.collectionCodecs(
// TODO: revisit - when elementCodec does transform, we create a new instance of transforming helper at every encode decode!
// TODO: also for a function say foo(List<Foo>): List<Foo>, we create it twice, one for the param and one for return
// TODO: wrap in KneeMapper instead of using withCollectionCodecs()
class CollectionCodec constructor(
class CollectionCodec(
private val context: KneeContext,
private val elementCodec: Codec,
private val collectionKind: CollectionKind
) : Codec(
localIrType = collectionKind.getCollectionTypeOf(elementCodec.localIrType, context.symbols),
localCodegenType = collectionKind.getCollectionTypeOf(elementCodec.localCodegenType, context.symbols),
encodedType = when (val type = elementCodec.encodedType) {
is JniType.Primitive -> type.array(context.symbols)
is JniType.Object -> type.array(context.symbols)
else -> error("Unsupported element type: $type")
is JniType.Real -> type.array(context.symbols)
is JniType.Void -> error("CollectionCodec<Void> is not supported.")
}
) {
/**
* The inner codec is the one that transforms the jobjectArray in a Collection<jobject>.
* We have two different implementations based on whether the encoded type is a primitive or not.
*/
private val runtimeHelperClassRaw: IrClass = when (val type = elementCodec.encodedType) {
is JniType.Void -> error("Void is not allowed here.")
is JniType.Primitive -> context.symbols.klass(PrimitiveCollectionCodec(type.knSimpleName)).owner
is JniType.Object -> context.symbols.klass(JObjectCollectionCodec).owner
else -> error("Not possible")
}

/**
Expand All @@ -93,11 +92,11 @@ class CollectionCodec constructor(

private fun IrBuilderWithScope.irGetOrCreateHelperRaw(): IrDeclarationReference {
return when (val type = elementCodec.encodedType) {
is JniType.Void -> error("Void is not allowed here.")
is JniType.Primitive -> irGetObject(runtimeHelperClassRaw.symbol)
is JniType.Object -> irCallConstructor(runtimeHelperClassRaw.primaryConstructor!!.symbol, emptyList()).apply {
putValueArgument(0, irString(type.jvm.jvmClassName))
}
else -> error("Should not happen")
}
}

Expand All @@ -118,6 +117,7 @@ class CollectionCodec constructor(
// Return type of this is symbols.klass(runtimeArraySpecClass)
// .typeWith(CollectionKind.Array.getCollectionType(elementCodec.localType, symbols), elementCodec.localType)
putValueArgument(1, when (val type = elementCodec.encodedType) {
is JniType.Void -> error("Void is not allowed here.")
is JniType.Primitive -> {
val name = PrimitiveArraySpec(type.jvmSimpleName)
irGetObject(this@CollectionCodec.context.symbols.klass(name))
Expand All @@ -128,7 +128,6 @@ class CollectionCodec constructor(
putTypeArgument(0, type.kn)
}
}
else -> error("Not possible")
})
// Constructor param: Source --> Transformed decoding lambda
putValueArgument(2, irLambda(
Expand Down Expand Up @@ -220,9 +219,9 @@ class CollectionCodec constructor(

override fun CodeBlock.Builder.codegenEncode(codegenContext: CodegenCodecContext, local: String): String {
val arrayName = when (val type = elementCodec.encodedType) {
is JniType.Void -> error("Void is not allowed here.")
is JniType.Primitive -> type.jvmSimpleName
is JniType.Object -> "Typed"
else -> error("Not possible")
}

return when (elementCodec.needsCodegenConversion) {
Expand Down
2 changes: 0 additions & 2 deletions knee-compiler-plugin/src/main/kotlin/codec/GenericCodec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class GenericCodec(

return when (wrappedType) {
is JniType.Object -> data // already a jobject
is JniType.Array -> data // already a jobject
is JniType.Long -> irEncodeBoxed("Long")
is JniType.Int -> irEncodeBoxed("Int")
is JniType.Double -> irEncodeBoxed("Double")
Expand All @@ -66,7 +65,6 @@ class GenericCodec(

val decoded = when (wrappedType) {
is JniType.Object -> irGet(jni) // irAs(irGet(jni), wrappedType.kn)
is JniType.Array -> irGet(jni) // irAs(irGet(jni), wrappedType.kn)
is JniType.Long -> irDecodeBoxed("Long")
is JniType.Int -> irDecodeBoxed("Int")
is JniType.Double -> irDecodeBoxed("Double")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ object UpwardFunctionsIr {
private val JniType.nameOfCallMethodFunction: String get() {
return when (this) {
is JniType.Void -> "Void"
is JniType.Object, is JniType.Array -> "Object"
is JniType.Object -> "Object"
is JniType.Int -> "Int"
is JniType.BooleanAsUByte -> "Boolean"
is JniType.Float -> "Float"
Expand Down
20 changes: 10 additions & 10 deletions knee-compiler-plugin/src/main/kotlin/jni/JniSignature.kt
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ object JniSignature {
is JniType.Double -> append('D')

// OBJECT TYPES
is JniType.Object -> {
append('L')
append(type.jvm.jvmClassName) // cares about dollar signs
append(';')
}

// ARRAY TYPES
is JniType.Array -> {
append("[")
appendJniType(type.element, isReturnType)
is JniType.Object -> when (val arrayElement = type.arrayElement) {
null -> {
append('L')
append(type.jvm.jvmClassName) // cares about dollar signs
append(';')
}
else -> {
append("[")
appendJniType(arrayElement, isReturnType)
}
}
}
}
Expand Down
58 changes: 23 additions & 35 deletions knee-compiler-plugin/src/main/kotlin/jni/JniType.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ sealed interface JniType {
sealed interface Real : JniType {
val kn: IrSimpleType
val jvm: CodegenType
val jvmArray: CodegenType
fun array(symbols: KneeSymbols): Object = Object.array(symbols, jvmArray, this)
}

@Serializable
Expand All @@ -45,84 +47,70 @@ sealed interface JniType {
// E.g. for JniType.Boolean, jvm = "Boolean" and kn = "UByte"
val jvmSimpleName get() = jvm.name.simpleName
val knSimpleName get() = kn.let { CodegenType.from(it) }.name.simpleName
fun array(symbols: KneeSymbols): Array
}

@Serializable
class Int private constructor(@Contextual override val kn: IrSimpleType) : Primitive {
constructor(symbols: KneeSymbols) : this(kn = symbols.builtIns.intType as IrSimpleType)
override val jvm get() = CodegenType.from(INT)
override fun array(symbols: KneeSymbols): Array {
return Array(symbols, CodegenType.from(INT_ARRAY), Int(symbols))
}
override val jvmArray get() = CodegenType.from(INT_ARRAY)
}

@Serializable
class Float private constructor(@Contextual override val kn: IrSimpleType) : Primitive {
constructor(symbols: KneeSymbols) : this(kn = symbols.builtIns.floatType as IrSimpleType)
override val jvm get() = CodegenType.from(FLOAT)
override fun array(symbols: KneeSymbols): Array {
return Array(symbols, CodegenType.from(FLOAT_ARRAY), Float(symbols))
}
override val jvmArray get() = CodegenType.from(FLOAT_ARRAY)
}

@Serializable
class Double private constructor(@Contextual override val kn: IrSimpleType) : Primitive {
constructor(symbols: KneeSymbols) : this(kn = symbols.builtIns.doubleType as IrSimpleType)
override val jvm get() = CodegenType.from(DOUBLE)
override fun array(symbols: KneeSymbols): Array {
return Array(symbols, CodegenType.from(DOUBLE_ARRAY), Double(symbols))
}
override val jvmArray get() = CodegenType.from(DOUBLE_ARRAY)
}

@Serializable
class Long private constructor(@Contextual override val kn: IrSimpleType) : Primitive {
constructor(symbols: KneeSymbols) : this(kn = symbols.builtIns.longType as IrSimpleType)
override val jvm get() = CodegenType.from(LONG)
override fun array(symbols: KneeSymbols): Array {
return Array(symbols, CodegenType.from(LONG_ARRAY), Long(symbols))
}
override val jvmArray get() = CodegenType.from(LONG_ARRAY)
}

@Serializable
class Byte private constructor(@Contextual override val kn: IrSimpleType) : Primitive {
constructor(symbols: KneeSymbols) : this(kn = symbols.builtIns.byteType as IrSimpleType)
override val jvm get() = CodegenType.from(BYTE)
override fun array(symbols: KneeSymbols): Array {
return Array(symbols, CodegenType.from(BYTE_ARRAY), Byte(symbols))
}
override val jvmArray get() = CodegenType.from(BYTE_ARRAY)
}

// The name makes it immediately clear that the types at the two ends are different
@Serializable
class BooleanAsUByte private constructor(@Contextual override val kn: IrSimpleType) : Primitive {
constructor(symbols: KneeSymbols) : this(kn = symbols.klass(KotlinIds.UByte).defaultType as IrSimpleType)
override val jvm get() = CodegenType.from(BOOLEAN)
override fun array(symbols: KneeSymbols): Array {
return Array(symbols, CodegenType.from(BOOLEAN_ARRAY), BooleanAsUByte(symbols))
}
}

@Serializable
class Object private constructor(@Contextual override val kn: IrSimpleType, override val jvm: CodegenType) : Real {
constructor(symbols: KneeSymbols, jvm: CodegenType) : this(
kn = symbols.typeAliasUnwrapped(PlatformIds.jobject) as IrSimpleType,
jvm = jvm
)
fun array(symbols: KneeSymbols): Array {
return Array(symbols, CodegenType.from(ARRAY.parameterizedBy(jvm.name)), Object(symbols, jvm))
}
override val jvmArray get() = CodegenType.from(BOOLEAN_ARRAY)
}

// Can be array of primitive or array of object
@Serializable
class Array private constructor(
class Object private constructor(
@Contextual override val kn: IrSimpleType,
override val jvm: CodegenType,
val element: Real
val arrayElement: Real?
) : Real {
constructor(symbols: KneeSymbols, jvm: CodegenType, element: Real) : this(
symbols.typeAliasUnwrapped(PlatformIds.jobjectArray) as IrSimpleType, jvm, element
constructor(symbols: KneeSymbols, jvm: CodegenType) : this(
kn = symbols.typeAliasUnwrapped(PlatformIds.jobject) as IrSimpleType,
jvm = jvm,
arrayElement = null
)
// val isArray get() = arrayElement != null
override val jvmArray get() = CodegenType.from(ARRAY.parameterizedBy(jvm.name))
companion object {
fun array(symbols: KneeSymbols, jvm: CodegenType, element: Real) = Object(
kn = symbols.typeAliasUnwrapped(PlatformIds.jobjectArray) as IrSimpleType,
jvm = jvm,
arrayElement = element
)
}
}
}

0 comments on commit 1687f1d

Please sign in to comment.