Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[0.19] Address caching issue in schema compiler that can lead to OOMs #1329

Open
wants to merge 5 commits into
base: series/0.19
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.19.0

* Reworked the CachedSchemaCompiler construct to prevent memory leaks when using schema-based implicit derivation in dynamic methods.

# 0.18.4

* Changes the behaviour of `Field#getUnlessDefault` and `Field#foreachUnlessDefault` to always take the value into consideration when the `smithy.api#required` trait
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2021-2023 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package smithy4s

import munit.FunSuite
import smithy4s.schema._

class CachedSchemaCompilerSpec() extends FunSuite {

test(
"CachedSchemaCompiler.Impl memoizes the result of `fromSchemaAux`"
) {
var x = 0
val compiler = new CachedSchemaCompiler.Impl[Option] {
def fromSchemaAux[A](schema: Schema[A], cache: AuxCache): Option[A] = {
x += 1
None
}
}
val cache = compiler.createCache()
discardResult(compiler.fromSchema(Schema.int, cache))
discardResult(compiler.fromSchema(Schema.int, cache))
assertEquals(x, 1)
}

test(
"CachedSchemaCompiler.Impl memoization is stable through mapK"
) {
var x = 0
val transformation = new smithy4s.kinds.PolyFunction[Option, Option] {
def apply[A](fa: Option[A]) = {
x += 1
fa
}
}
val compiler = new CachedSchemaCompiler.Impl[Option] {
def fromSchemaAux[A](schema: Schema[A], cache: AuxCache): Option[A] = {
None
}
}.mapK(transformation)
val cache = compiler.createCache()
discardResult(compiler.fromSchema(Schema.int, cache))
discardResult(compiler.fromSchema(Schema.int, cache))
assertEquals(x, 1)
}

test(
"CachedSchemaCompiler.Impl memoization is stable through contramapSchema"
) {
var x = 0
val transformation = new smithy4s.kinds.PolyFunction[Schema, Schema] {
def apply[A](fa: Schema[A]) = {
x += 1
fa
}
}
val compiler = new CachedSchemaCompiler.Impl[Option] {
def fromSchemaAux[A](schema: Schema[A], cache: AuxCache): Option[A] = {
None
}
}.contramapSchema(transformation)
val cache = compiler.createCache()
discardResult(compiler.fromSchema(Schema.int, cache))
discardResult(compiler.fromSchema(Schema.int, cache))
assertEquals(x, 1)
}

private def discardResult[A](f: => A): Unit = {
val _ = f
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@ class StringAndBlobSpec() extends munit.FunSuite {
val error = PayloadError(PayloadPath.root, "error", "error")
object DummyDecoderCompiler
extends CachedSchemaCompiler.Impl[PayloadDecoder] {
def fromSchema[A](schema: Schema[A], cache: Cache): PayloadDecoder[A] =
def fromSchemaAux[A](
schema: Schema[A],
cache: AuxCache
): PayloadDecoder[A] =
Decoder.static(Left(error): Either[PayloadError, A])
}

object DummyWriterCompiler extends CachedSchemaCompiler.Impl[PayloadEncoder] {
def fromSchema[A](schema: Schema[A], cache: Cache): PayloadEncoder[A] =
def fromSchemaAux[A](
schema: Schema[A],
cache: AuxCache
): PayloadEncoder[A] =
Encoder.static(Blob.empty): Encoder[Blob, A]

}
Expand Down
4 changes: 2 additions & 2 deletions modules/cats/src/smithy4s/interopcats/SchemaVisitorHash.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import scala.util.hashing.MurmurHash3.productSeed

object SchemaVisitorHash extends CachedSchemaCompiler.Impl[Hash] {
protected type Aux[A] = Hash[A]
def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Hash[A] = {
schema.compile(new SchemaVisitorHash(cache))
}
Expand Down
4 changes: 2 additions & 2 deletions modules/cats/src/smithy4s/interopcats/SchemaVisitorShow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ import smithy4s.schema.Alt.Precompiler

object SchemaVisitorShow extends CachedSchemaCompiler.Impl[Show] {
protected type Aux[A] = Show[A]
def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Show[A] = {
schema.compile(new SchemaVisitorShow(cache))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ object CanonicalSmithyDecoder {

protected type Aux[A] = smithy4s.internals.DocumentDecoder[A]

def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Decoder[A] = {
val decodeFunction =
schema.compile(new SmithyNodeDocumentDecoderSchemaVisitor(cache))
Expand Down
8 changes: 4 additions & 4 deletions modules/core/src/smithy4s/Document.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ object Document {

protected type Aux[A] = internals.DocumentEncoder[A]

def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Encoder[A] = {
val makeEncoder =
schema.compile(new DocumentEncoderSchemaVisitor(cache))
Expand All @@ -127,9 +127,9 @@ object Document {

protected type Aux[A] = internals.DocumentDecoder[A]

def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Decoder[A] = {
val decodeFunction =
schema.compile(new DocumentDecoderSchemaVisitor(cache))
Expand Down
8 changes: 4 additions & 4 deletions modules/core/src/smithy4s/codecs/StringAndBlobCodecs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ import smithy4s.capability.instances.either._
object StringAndBlobCodecs {

object decoders extends CachedSchemaCompiler.Optional.Impl[BlobDecoder] {
def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Option[BlobDecoder[A]] =
StringAndBlobReaderVisitor(schema)
}

object encoders extends CachedSchemaCompiler.Optional.Impl[BlobEncoder] {
def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Option[BlobEncoder[A]] =
StringAndBlobWriterVisitor(schema)
}
Expand Down
5 changes: 4 additions & 1 deletion modules/core/src/smithy4s/http/HttpStatusCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ object HttpStatusCode extends CachedSchemaCompiler.Impl[HttpStatusCode] {
instance
type Aux[A] = internals.HttpCode[A]

def fromSchema[A](schema: Schema[A], cache: Cache): HttpStatusCode[A] = {
def fromSchemaAux[A](
schema: Schema[A],
cache: AuxCache
): HttpStatusCode[A] = {
val visitor = new internals.ErrorCodeSchemaVisitor(cache)
val go = schema.compile(visitor)
new HttpStatusCode[A] {
Expand Down
9 changes: 4 additions & 5 deletions modules/core/src/smithy4s/http/Metadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import smithy4s.http.internals.MetaEncode._
import smithy4s.http.internals.SchemaVisitorMetadataReader
import smithy4s.http.internals.SchemaVisitorMetadataWriter
import smithy4s.schema.CachedSchemaCompiler
import smithy4s.schema.CompilationCache

/**
* Datatype containing metadata associated to a http message.
Expand Down Expand Up @@ -185,9 +184,9 @@ object Metadata {
def apply[A](implicit instance: Decoder[A]): Decoder[A] =
instance

def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: CompilationCache[internals.MetaDecode]
cache: AuxCache
): Decoder[A] = {
val metaDecode =
new SchemaVisitorMetadataReader(cache, awsHeaderEncoding)(schema)
Expand Down Expand Up @@ -224,9 +223,9 @@ object Metadata {

def apply[A](implicit instance: Encoder[A]): Encoder[A] = instance

def fromSchema[A](
def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Encoder[A] = {
val toStatusCode: A => Option[Int] = { a =>
schema.compile(new HttpResponseCodeSchemaVisitor()) match {
Expand Down
8 changes: 4 additions & 4 deletions modules/core/src/smithy4s/http/UrlForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ object UrlForm {
): CachedSchemaCompiler[Decoder] =
new CachedSchemaCompiler.Impl[Decoder] {
protected override type Aux[A] = UrlFormDataDecoder[A]
override def fromSchema[A](
override def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Decoder[A] = {
val schemaVisitor = new UrlFormDataDecoderSchemaVisitor(
cache,
Expand All @@ -201,9 +201,9 @@ object UrlForm {
): CachedSchemaCompiler[Encoder] =
new CachedSchemaCompiler.Impl[Encoder] {
protected override type Aux[A] = UrlFormDataEncoder[A]
override def fromSchema[A](
override def fromSchemaAux[A](
schema: Schema[A],
cache: Cache
cache: AuxCache
): Encoder[A] = {
val maybeStaticUrlFormData =
schema.hints.get(internals.StaticUrlFormElements).map {
Expand Down
49 changes: 42 additions & 7 deletions modules/core/src/smithy4s/schema/CachedSchemaCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ trait CachedSchemaCompiler[+F[_]] { self =>
def fromSchema[A](schema: Schema[A]): F[A]
def fromSchema[A](schema: Schema[A], cache: Cache): F[A]

final def mapK[F0[x] >: F[x], G[_]](
def mapK[F0[x] >: F[x], G[_]](
fk: PolyFunction[F0, G]
): CachedSchemaCompiler[G] =
new CachedSchemaCompiler[G] {
Expand All @@ -39,7 +39,7 @@ trait CachedSchemaCompiler[+F[_]] { self =>
)
}

final def contramapSchema(
def contramapSchema(
fk: PolyFunction[Schema, Schema]
): CachedSchemaCompiler[F] = new CachedSchemaCompiler[F] {
type Cache = self.Cache
Expand All @@ -49,7 +49,6 @@ trait CachedSchemaCompiler[+F[_]] { self =>

def fromSchema[A](schema: Schema[A], cache: Cache): F[A] =
self.fromSchema(fk(schema), cache)

}

}
Expand Down Expand Up @@ -86,14 +85,50 @@ object CachedSchemaCompiler { outer =>
self.mapK(fk)
}

abstract class Impl[F[_]] extends CachedSchemaCompiler[F] {
abstract class Impl[F[_]] extends CachedSchemaCompiler[F] { self =>
protected type Aux[_]
type Cache = CompilationCache[Aux]
type AuxCache = CompilationCache[Aux]
case class Cache(outer: CompilationCache[F], inner: AuxCache)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the core of the change : we need an inner cache to be passed to some schema visitor, but we also need an outer cache to prevent schema-preprocessing from being re-applied

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: is it only a problem because the preprocessing doesn't produce deterministic (as per hashCode/equals) outputs?

Do we know which transformation was causing the particular issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's 2 problems : most pernicious one is what you describe, and it's the total function held by enumeration schemas which is at fault. Hence the changes in my other PR.

Less pernicious one is that the it's the input of the schema visitor call that gets cached instead of the input of the schema compiler.

This means that we're not protecting against re-running the inefficient pre-processing of schemas that may occur (like hint masks), which is really bad performance wise

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's good info, thanks


def fromSchemaAux[A](
schema: Schema[A],
innerCache: CompilationCache[Aux]
): F[A]

override final def fromSchema[A](schema: Schema[A]): F[A] =
fromSchema(schema, CompilationCache.nop[Aux])
fromSchema(
schema,
Cache(CompilationCache.nop[F], CompilationCache.nop[Aux])
)

override final def fromSchema[A](schema: Schema[A], cache: Cache) = {
cache.outer.getOrElseUpdate(schema, s => fromSchemaAux(s, cache.inner))
}

override final def mapK[F0[x] >: F[x], G[_]](
fk: PolyFunction[F0, G]
): CachedSchemaCompiler[G] = new Impl[G] {
type Aux[A] = self.Aux[A]
def fromSchemaAux[A](
schema: Schema[A],
innerCache: CompilationCache[Aux]
): G[A] =
fk(self.fromSchemaAux(schema, innerCache))
}

override final def contramapSchema(
fk: PolyFunction[Schema, Schema]
): CachedSchemaCompiler[F] = new Impl[F] {
type Aux[A] = self.Aux[A]
def fromSchemaAux[A](
schema: Schema[A],
innerCache: CompilationCache[Aux]
): F[A] =
self.fromSchemaAux(fk(schema), innerCache)
}

def createCache(): Cache = CompilationCache.make[Aux]
final def createCache(): Cache =
Cache(CompilationCache.make[F], CompilationCache.make[Aux])
}

abstract class Uncached[F[_]] extends CachedSchemaCompiler[F] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ private[smithy4s] case class JsoniterCodecCompilerImpl(
): JsoniterCodecCompiler =
copy(preserveMapOrder = preserveMapOrder)

def fromSchema[A](schema: Schema[A], cache: Cache): JCodec[A] = {
def fromSchemaAux[A](schema: Schema[A], cache: AuxCache): JCodec[A] = {
val visitor = new SchemaVisitorJCodec(
maxArity,
explicitDefaultsEncoding,
Expand Down
4 changes: 2 additions & 2 deletions modules/xml/src/smithy4s/xml/XmlDocument.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ object XmlDocument {
implicit def decoderFromSchema[A: Schema]: Decoder[A] = Decoder.derivedImplicitInstance
object Decoder extends CachedSchemaCompiler.DerivingImpl[Decoder] {
protected override type Aux[A] = XmlDecoder[A]
def fromSchema[A](schema: Schema[A], cache: Cache): Decoder[A] = {
def fromSchemaAux[A](schema: Schema[A], cache: AuxCache): Decoder[A] = {
val startingPath: List[XmlQName] = getStartingPath(schema)
val schemaVisitor = new XmlDecoderSchemaVisitor(cache)
val xmlDecoder = schemaVisitor(schema)
Expand All @@ -126,7 +126,7 @@ object XmlDocument {
}
object Encoder extends CachedSchemaCompiler.Impl[Encoder] {
protected override type Aux[A] = XmlEncoder[A]
def fromSchema[A](schema: Schema[A], cache: Cache): Encoder[A] = {
def fromSchemaAux[A](schema: Schema[A], cache: AuxCache): Encoder[A] = {
val rootName: XmlQName = getRootName(schema)
val rootNamespace =
schema.hints
Expand Down