Skip to content

Commit

Permalink
codegen: Improve enum support (#3861)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson authored Jun 26, 2024
1 parent c91e0cb commit 9032555
Show file tree
Hide file tree
Showing 31 changed files with 1,104 additions and 253 deletions.
1 change: 0 additions & 1 deletion doc/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ jsoniter "com.github.plokhotnyuk.jsoniter-scala" %% "jsoniter-scala
Currently, string-like enums in Scala 2 depend upon the enumeratum library (`"com.beachape" %% "enumeratum"`).
For Scala 3 we derive native enums, and depend on `"io.github.bishabosha" %% "enum-extensions"` for generating query
param serdes.
Other forms of OpenApi enum are not currently supported.

Models containing binary data cannot be re-used between json and multi-part form endpoints, due to having different
representation types for the binary data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ object BasicGenerator {
JsonSerdeLib.Circe
}

val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs) = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib)
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
Expand All @@ -59,7 +60,8 @@ object BasicGenerator {
jsonParamRefs = jsonParamRefs,
fullModelPath = s"$packagePath.$objName",
validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs,
maxSchemasPerFile = maxSchemasPerFile
maxSchemasPerFile = maxSchemasPerFile,
enumsDefinedOnEndpointParams = enumsDefinedOnEndpointParams
)
.getOrElse(GeneratedClassDefinitions("", None, Nil))
val hasJsonSerdes = jsonSerdes.nonEmpty
Expand Down Expand Up @@ -140,13 +142,50 @@ object BasicGenerator {
.mkString("\n")

val extraImports = if (endpointsInMain.nonEmpty) s"$maybeJsonImport$maybeSchemaImport" else ""
val queryParamSupport =
"""
|case class CommaSeparatedValues[T](values: List[T])
|case class ExplodedValues[T](values: List[T])
|trait QueryParamSupport[T] {
| def decode(s: String): sttp.tapir.DecodeResult[T]
| def encode(t: T): String
|}
|implicit def makeQueryCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(support.decode)(support.encode)
|}
|implicit def makeQueryOptCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
|}
|implicit def makeUnexplodedQuerySeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], CommaSeparatedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(values => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(s => CommaSeparatedValues(s.toList)))(_.values.map(support.encode).mkString(","))
|}
|implicit def makeUnexplodedQueryOptSeqCodecFromListHead[T](implicit support: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], Option[CommaSeparatedValues[T]], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode{
| case None => DecodeResult.Value(None)
| case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
| }(_.map(_.values.map(support.encode).mkString(",")))
|}
|implicit def makeExplodedQuerySeqCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.list[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(values => DecodeResult.sequence(values.map(support.decode)).map(s => ExplodedValues(s.toList)))(_.values.map(support.encode))
|}
|implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
|}
|""".stripMargin
val mainObj = s"""
|package $packagePath
|
|object $objName {
|
|${indent(2)(imports(normalisedJsonLib) + extraImports)}
|
|${indent(2)(queryParamSupport)}
|
|${indent(2)(classDefns)}
|
|${indent(2)(maybeSpecificationExtensionKeys)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ class ClassDefinitionGenerator {
jsonParamRefs: Set[String] = Set.empty,
fullModelPath: String = "",
validateNonDiscriminatedOneOfs: Boolean = true,
maxSchemasPerFile: Int = 400
maxSchemasPerFile: Int = 400,
enumsDefinedOnEndpointParams: Boolean = false
): Option[GeneratedClassDefinitions] = {
val allSchemas: Map[String, OpenapiSchemaType] = doc.components.toSeq.flatMap(_.schemas).toMap
val allOneOfSchemas = allSchemas.collect { case (name, oneOf: OpenapiSchemaOneOf) => name -> oneOf }.toSeq
val adtInheritanceMap: Map[String, Seq[String]] = mkMapParentsByChild(allOneOfSchemas)
val generatesQueryParamEnums =
val generatesQueryParamEnums = enumsDefinedOnEndpointParams ||
allSchemas
.collect { case (name, _: OpenapiSchemaEnum) => name }
.exists(queryParamRefs.contains)
Expand All @@ -49,14 +50,15 @@ class ClassDefinitionGenerator {
allTransitiveJsonParamRefs,
fullModelPath,
validateNonDiscriminatedOneOfs,
adtInheritanceMap
adtInheritanceMap,
targetScala3
)
val defns = doc.components
.map(_.schemas.flatMap {
case (name, obj: OpenapiSchemaObject) =>
generateClass(allSchemas, name, obj, allTransitiveJsonParamRefs, adtInheritanceMap)
generateClass(allSchemas, name, obj, allTransitiveJsonParamRefs, adtInheritanceMap, jsonSerdeLib, targetScala3)
case (name, obj: OpenapiSchemaEnum) =>
generateEnum(name, obj, targetScala3, queryParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
EnumGenerator.generateEnum(name, obj, targetScala3, queryParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema)
case (_, _: OpenapiSchemaOneOf) => Nil
case (n, x) => throw new NotImplementedError(s"Only objects, enums and maps supported! (for $n found ${x})")
Expand Down Expand Up @@ -95,50 +97,55 @@ class ClassDefinitionGenerator {
.groupBy(_._1)
.mapValues(_.map(_._2))

private def enumQuerySerdeHelperDefn(targetScala3: Boolean): String = if (targetScala3)
"""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|
|def makeQueryCodecForEnum[T: enumextensions.EnumMirror]: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
| sttp.tapir.Codec
| .listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(s =>
| // Case-insensitive mapping
| scala.util
| .Try(enumMap[T](using enumextensions.EnumMirror[T])(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumextensions.EnumMirror[T].mirroredName}, available values: ${enumextensions.EnumMirror[T].values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| )(_.name)
|""".stripMargin
else
"""def makeQueryCodecForEnum[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(s =>
| // Case-insensitive mapping
| scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumName}, available values: ${T.values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| )(_.entryName)
|""".stripMargin
private def enumQuerySerdeHelperDefn(targetScala3: Boolean): String = {
if (targetScala3)
"""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|case class EnumQueryParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends QueryParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util
| .Try(eMap(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumextensions.EnumMirror[T].mirroredName}, available values: ${enumextensions.EnumMirror[T].values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| def encode(t: T): String = t.name
|}
|def queryCodecSupport[T: enumextensions.EnumMirror]: QueryParamSupport[T] =
| EnumQueryParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|""".stripMargin
else
"""
|case class EnumQueryParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends QueryParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumName}, available values: ${T.values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| def encode(t: T): String = t.entryName
|}
|def queryCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): QueryParamSupport[T] =
| EnumQueryParamSupport(enumName, T)
|""".stripMargin
}

@tailrec
final def recursiveFindAllReferencedSchemaTypes(
Expand Down Expand Up @@ -191,63 +198,14 @@ class ClassDefinitionGenerator {
Seq(s"""type $name = Map[String, $valueSchemaName]""")
}

// Uses enumeratum for scala 2, but generates scala 3 enums instead where it can
private[codegen] def generateEnum(
name: String,
obj: OpenapiSchemaEnum,
targetScala3: Boolean,
queryParamRefs: Set[String],
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
jsonParamRefs: Set[String]
): Seq[String] = if (targetScala3) {
val maybeCompanion =
if (queryParamRefs contains name)
s"""
|object $name {
| given stringList${name}Codec: sttp.tapir.Codec[List[String], $name, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum[$name]
|}""".stripMargin
else ""
val maybeCodecExtensions = jsonSerdeLib match {
case _ if !jsonParamRefs.contains(name) && !queryParamRefs.contains(name) => ""
case _ if !jsonParamRefs.contains(name) => " derives enumextensions.EnumMirror"
case JsonSerdeLib.Circe if !queryParamRefs.contains(name) => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec"
case JsonSerdeLib.Circe => " derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec, enumextensions.EnumMirror"
case JsonSerdeLib.Jsoniter if !queryParamRefs.contains(name) => s" extends java.lang.Enum[$name]"
case JsonSerdeLib.Jsoniter => s" extends java.lang.Enum[$name] derives enumextensions.EnumMirror"
}
s"""$maybeCompanion
|enum $name$maybeCodecExtensions {
| case ${obj.items.map(_.value).mkString(", ")}
|}""".stripMargin :: Nil
} else {
val uncapitalisedName = BasicGenerator.uncapitalise(name)
val members = obj.items.map { i => s"case object ${i.value} extends $name" }
val maybeCodecExtension = jsonSerdeLib match {
case _ if !jsonParamRefs.contains(name) && !queryParamRefs.contains(name) => ""
case JsonSerdeLib.Circe => s" with enumeratum.CirceEnum[$name]"
case JsonSerdeLib.Jsoniter => ""
}
val maybeQueryCodecDefn =
if (queryParamRefs contains name)
s"""
| implicit val ${uncapitalisedName}QueryCodec: sttp.tapir.Codec[List[String], ${name}, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum("${name}", ${name})""".stripMargin
else ""
s"""
|sealed trait $name extends enumeratum.EnumEntry
|object $name extends enumeratum.Enum[$name]$maybeCodecExtension {
| val values = findValues
|${indent(2)(members.mkString("\n"))}$maybeQueryCodecDefn
|}""".stripMargin :: Nil
}

private[codegen] def generateClass(
allSchemas: Map[String, OpenapiSchemaType],
name: String,
obj: OpenapiSchemaObject,
jsonParamRefs: Set[String],
adtInheritanceMap: Map[String, Seq[String]]
adtInheritanceMap: Map[String, Seq[String]],
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
targetScala3: Boolean
): Seq[String] = {
val isJson = jsonParamRefs contains name
def rec(name: String, obj: OpenapiSchemaObject, acc: List[String]): Seq[String] = {
Expand All @@ -268,24 +226,25 @@ class ClassDefinitionGenerator {
.flatten
.toList

val properties = obj.properties.map { case (key, OpenapiSchemaField(schemaType, maybeDefault)) =>
val tpe = mapSchemaTypeToType(name, key, obj.required.contains(key), schemaType, isJson)
val (properties, maybeEnums) = obj.properties.map { case (key, OpenapiSchemaField(schemaType, maybeDefault)) =>
val (tpe, maybeEnum) = mapSchemaTypeToType(name, key, obj.required.contains(key), schemaType, isJson, jsonSerdeLib, targetScala3)
val fixedKey = fixKey(key)
val optional = schemaType.nullable || !obj.required.contains(key)
val maybeExplicitDefault =
maybeDefault.map(" = " + DefaultValueRenderer.render(allModels = allSchemas, thisType = schemaType, optional)(_))
val default = maybeExplicitDefault getOrElse (if (optional) " = None" else "")
s"$fixedKey: $tpe$default"
}
s"$fixedKey: $tpe$default" -> maybeEnum
}.unzip

val parents = adtInheritanceMap.getOrElse(name, Nil) match {
case Nil => ""
case ps => ps.mkString(" extends ", " with ", "")
}

val enumDefn = maybeEnums.flatten.toList
s"""|case class $name (
|${indent(2)(properties.mkString(",\n"))}
|)$parents""".stripMargin :: innerClasses ::: acc
|)$parents""".stripMargin :: innerClasses ::: enumDefn ::: acc
}

rec(addName("", name), obj, Nil)
Expand All @@ -296,28 +255,52 @@ class ClassDefinitionGenerator {
key: String,
required: Boolean,
schemaType: OpenapiSchemaType,
isJson: Boolean
): String = {
val (tpe, optional) = schemaType match {
isJson: Boolean,
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
targetScala3: Boolean
): (String, Option[String]) = {
val ((tpe, optional), maybeEnum) = schemaType match {
case simpleType: OpenapiSchemaSimpleType =>
mapSchemaSimpleTypeToType(simpleType, multipartForm = !isJson)
mapSchemaSimpleTypeToType(simpleType, multipartForm = !isJson) -> None

case objectType: OpenapiSchemaObject =>
addName(parentName, key) -> objectType.nullable
(addName(parentName, key) -> objectType.nullable, None)

case mapType: OpenapiSchemaMap =>
val innerType = mapSchemaTypeToType(addName(parentName, key), "item", required = true, mapType.items, isJson = isJson)
s"Map[String, $innerType]" -> mapType.nullable
val (innerType, maybeEnum) =
mapSchemaTypeToType(addName(parentName, key), "item", required = true, mapType.items, isJson = isJson, jsonSerdeLib, targetScala3)
(s"Map[String, $innerType]" -> mapType.nullable, maybeEnum)

case arrayType: OpenapiSchemaArray =>
val innerType = mapSchemaTypeToType(addName(parentName, key), "item", required = true, arrayType.items, isJson = isJson)
s"Seq[$innerType]" -> arrayType.nullable
val (innerType, maybeEnum) =
mapSchemaTypeToType(
addName(parentName, key),
"item",
required = true,
arrayType.items,
isJson = isJson,
jsonSerdeLib,
targetScala3
)
(s"Seq[$innerType]" -> arrayType.nullable, maybeEnum)

case e: OpenapiSchemaEnum =>
val enumName = addName(parentName.capitalize, key)
val enumDefn = EnumGenerator.generateEnum(
enumName,
e,
targetScala3,
Set.empty,
jsonSerdeLib,
if (isJson) Set(enumName) else Set.empty
)
(enumName -> e.nullable, Some(enumDefn.mkString("\n")))

case _ =>
throw new NotImplementedError(s"We can't serialize some of the properties yet! $parentName $key $schemaType")
}

if (optional || !required) s"Option[$tpe]" else tpe
(if (optional || !required) s"Option[$tpe]" else tpe, maybeEnum)
}

private def addName(parentName: String, key: String) = parentName + key.replace('_', ' ').replace('-', ' ').capitalize.replace(" ", "")
Expand Down
Loading

0 comments on commit 9032555

Please sign in to comment.