Skip to content

Commit

Permalink
codegen: support enum query params (#3602)
Browse files Browse the repository at this point in the history
Co-authored-by: kciesielski <[email protected]>
  • Loading branch information
hughsimpson and kciesielski authored Mar 15, 2024
1 parent d6e9cad commit 633810b
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 38 deletions.
2 changes: 1 addition & 1 deletion doc/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ having no tags, would be output to the `TapirGeneratedEndpoints` file, along wit
Currently, the generated code depends on `"io.circe" %% "circe-generic"`. In the future probably we will make the encoder/decoder json lib configurable (PRs welcome).

String-like enums in Scala 2 depend on both `"com.beachape" %% "enumeratum"` and `"com.beachape" %% "enumeratum-circe"`.
For Scala 3 we derive native enums, and depend instead on `"org.latestbit" %% "circe-tagged-adt-codec"`.
For Scala 3 we derive native enums, and depend on `"org.latestbit" %% "circe-tagged-adt-codec"` for json serdes and `"io.github.bishabosha" %% "enum-extensions"` for query param serdes.
Other forms of OpenApi enum are not currently supported.

We currently miss a lot of OpenApi features like:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaBinary,
OpenapiSchemaDateTime,
OpenapiSchemaDouble,
OpenapiSchemaEnum,
OpenapiSchemaFloat,
OpenapiSchemaInt,
OpenapiSchemaLong,
Expand All @@ -29,11 +28,7 @@ object BasicGenerator {
targetScala3: Boolean,
useHeadTagForObjectNames: Boolean
): Map[String, String] = {
val enumImport =
if (!targetScala3 && doc.components.toSeq.flatMap(_.schemas).exists(_._2.isInstanceOf[OpenapiSchemaEnum])) "\n import enumeratum._"
else ""

val endpointsByTag = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val EndpointDefs(endpointsByTag, queryParamRefs) = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val taggedObjs = endpointsByTag.collect {
case (Some(headTag), body) if body.nonEmpty =>
val taggedObj =
Expand All @@ -55,9 +50,9 @@ object BasicGenerator {
|
|object $objName {
|
|${indent(2)(imports)}$enumImport
|${indent(2)(imports)}
|
|${indent(2)(classGenerator.classDefs(doc, targetScala3).getOrElse(""))}
|${indent(2)(classGenerator.classDefs(doc, targetScala3, queryParamRefs).getOrElse(""))}
|
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.OpenapiDocument
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
OpenapiSchemaConstantString,
OpenapiSchemaEnum,
OpenapiSchemaMap,
OpenapiSchemaObject,
Expand All @@ -14,17 +13,68 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{

class ClassDefinitionGenerator {

def classDefs(doc: OpenapiDocument, targetScala3: Boolean = false): Option[String] = {
doc.components
def classDefs(doc: OpenapiDocument, targetScala3: Boolean = false, queryParamRefs: Set[String] = Set.empty): Option[String] = {
val generatesQueryParamEnums =
doc.components.toSeq
.flatMap(_.schemas.collect { case (name, _: OpenapiSchemaEnum) => name })
.exists(queryParamRefs.contains)
val enumQuerySerdeHelper =
if (!generatesQueryParamEnums) ""
else if (targetScala3)
"""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|
|def makeQueryCodecForEnum[T: enumextensions.EnumMirror]: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
| sttp.tapir.Codec
| .listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(s =>
| // Case-insensitive mapping
| scala.util
| .Try(enumMap[T](using enumextensions.EnumMirror[T])(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumextensions.EnumMirror[T].mirroredName}, available values: ${enumextensions.EnumMirror[T].values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| )(_.name)
|""".stripMargin
else
"""def makeQueryCodecForEnum[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(s =>
| // Case-insensitive mapping
| scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumName}, available values: ${T.values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| )(_.entryName)
|""".stripMargin
val defns = doc.components
.map(_.schemas.flatMap {
case (name, obj: OpenapiSchemaObject) =>
generateClass(name, obj)
case (name, obj: OpenapiSchemaEnum) =>
generateEnum(name, obj, targetScala3)
generateEnum(name, obj, targetScala3, queryParamRefs)
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema)
case (n, x) => throw new NotImplementedError(s"Only objects, enums and maps supported! (for $n found ${x})")
})
.map(_.mkString("\n"))
defns.map(enumQuerySerdeHelper + _)
}

private[codegen] def generateMap(name: String, valueSchema: OpenapiSchemaType): Seq[String] = {
Expand All @@ -36,16 +86,37 @@ class ClassDefinitionGenerator {
}

// Uses enumeratum for scala 2, but generates scala 3 enums instead where it can
private[codegen] def generateEnum(name: String, obj: OpenapiSchemaEnum, targetScala3: Boolean): Seq[String] = if (targetScala3) {
s"""enum $name derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec {
private[codegen] def generateEnum(
name: String,
obj: OpenapiSchemaEnum,
targetScala3: Boolean,
queryParamRefs: Set[String]
): Seq[String] = if (targetScala3) {
val maybeQueryParamSerdeDerivation = if (queryParamRefs contains name) ", enumextensions.EnumMirror" else ""
val maybeCompanion =
if (queryParamRefs contains name)
s"""
|object $name {
| given stringList${name}Codec: sttp.tapir.Codec[List[String], $name, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum[$name]
|}""".stripMargin
else ""
s"""$maybeCompanion
|enum $name derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec$maybeQueryParamSerdeDerivation {
| case ${obj.items.map(_.value).mkString(", ")}
|}""".stripMargin :: Nil
} else {
val members = obj.items.map { i => s"case object ${i.value} extends $name" }
s"""|sealed trait $name extends EnumEntry
|object $name extends Enum[$name] with CirceEnum[$name] {
val maybeQueryCodecDefn =
if (queryParamRefs contains name)
s"""
| implicit val ${name.head.toLower +: name.tail}Codec: sttp.tapir.Codec[List[String], ${name}, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum("${name}", ${name})""".stripMargin
else ""
s"""|sealed trait $name extends enumeratum.EnumEntry
|object $name extends enumeratum.Enum[$name] with enumeratum.CirceEnum[$name] {
| val values = findValues
|${indent(2)(members.mkString("\n"))}
|${indent(2)(members.mkString("\n"))}$maybeQueryCodecDefn
|}""".stripMargin :: Nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,25 @@ case class Location(path: String, method: String) {
override def toString: String = s"${method.toUpperCase} ${path}"
}

case class GeneratedEndpoints(namesAndBodies: Seq[(Option[String], Seq[(String, String)])], queryParamRefs: Set[String]) {
def merge(that: GeneratedEndpoints): GeneratedEndpoints =
GeneratedEndpoints(
(namesAndBodies ++ that.namesAndBodies).groupBy(_._1).mapValues(_.map(_._2).reduce(_ ++ _)).toSeq,
queryParamRefs ++ that.queryParamRefs
)
}
case class EndpointDefs(endpointDecls: Map[Option[String], String], queryParamRefs: Set[String])

class EndpointGenerator {
private def bail(msg: String)(implicit location: Location): Nothing = throw new NotImplementedError(s"$msg at $location")

private[codegen] def allEndpoints: String = "generatedEndpoints"

def endpointDefs(doc: OpenapiDocument, useHeadTagForObjectNames: Boolean): Map[Option[String], String] = {
def endpointDefs(doc: OpenapiDocument, useHeadTagForObjectNames: Boolean): EndpointDefs = {
val components = Option(doc.components).flatten
val geMap =
doc.paths.flatMap(generatedEndpoints(components, useHeadTagForObjectNames)).groupBy(_._1).mapValues(_.map(_._2).reduce(_ ++ _))
geMap.mapValues { ge =>
val GeneratedEndpoints(geMap, queryParamRefs) =
doc.paths.map(generatedEndpoints(components, useHeadTagForObjectNames)).foldLeft(GeneratedEndpoints(Nil, Set.empty))(_ merge _)
val endpointDecls = geMap.map { case (k, ge) =>
val definitions = ge
.map { case (name, definition) =>
s"""|lazy val $name =
Expand All @@ -33,20 +42,21 @@ class EndpointGenerator {
.mkString("\n")
val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"

s"""|$definitions
k -> s"""|$definitions
|
|$allEP
|""".stripMargin
}.toMap
EndpointDefs(endpointDecls, queryParamRefs)
}

private[codegen] def generatedEndpoints(components: Option[OpenapiComponent], useHeadTagForObjectNames: Boolean)(
p: OpenapiPath
): Seq[(Option[String], Seq[(String, String)])] = {
): GeneratedEndpoints = {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)

p.methods
val (fileNamesAndParams, unflattenedQueryParamRefs) = p.methods
.map(_.withResolvedParentParameters(parameters, p.parameters))
.map { m =>
implicit val location: Location = Location(p.url, m.methodType)
Expand All @@ -68,11 +78,18 @@ class EndpointGenerator {
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
(maybeTargetFileName, (name, definition))
val queryParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" => queryParam.schema }
.collect { case OpenapiSchemaRef(ref) if ref.startsWith("#/components/schemas/") => ref.stripPrefix("#/components/schemas/") }
.toSet
(maybeTargetFileName, (name, definition)) -> queryParamRefs
}
.unzip
val namesAndParamsByFile = fileNamesAndParams
.groupBy(_._1)
.toSeq
.map { case (maybeTargetFileName, defns) => maybeTargetFileName -> defns.map(_._2) }
GeneratedEndpoints(namesAndParamsByFile, unflattenedQueryParamRefs.foldLeft(Set.empty[String])(_ ++ _))
}

private def urlMapper(url: String, parameters: Seq[OpenapiParameter])(implicit location: Location): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,14 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
(schemas + "\n" + (endpoints.linesIterator.filterNot(_ startsWith "package").mkString("\n"))) shouldCompile ()
}

it should "compile endpoints with enum query params" in {
BasicGenerator.generateObjects(
TestHelpers.enumQueryParamDocs,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = false
)("TapirGeneratedEndpoints") shouldCompile ()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
)
)
// the enumeratum import should be included by the BasicGenerator iff we generated enums
"import enumeratum._;" + (new ClassDefinitionGenerator().classDefs(doc).get) shouldCompile ()
new ClassDefinitionGenerator().classDefs(doc).get shouldCompile ()
}

it should "generate simple class with reserved propName" in {
Expand Down Expand Up @@ -280,10 +280,44 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {

val gen = new ClassDefinitionGenerator()
val res = gen.classDefs(doc, true)
// can't just check whether this compiles, because our tests only run on scala 2.12 - so instead just eyeball it...
res shouldBe Some("""enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec {
| case enum1, enum2
|}""".stripMargin)
val resWithQueryParamCodec = gen.classDefs(doc, true, queryParamRefs = Set("Test"))
// can't just check whether these compile, because our tests only run on scala 2.12 - so instead just eyeball it...
res shouldBe Some("""
|enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec {
| case enum1, enum2
|}""".stripMargin)
resWithQueryParamCodec shouldBe Some("""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|
|def makeQueryCodecForEnum[T: enumextensions.EnumMirror]: sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] =
| sttp.tapir.Codec
| .listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(s =>
| // Case-insensitive mapping
| scala.util
| .Try(enumMap[T](using enumextensions.EnumMirror[T])(s.toUpperCase))
| .fold(
| _ =>
| sttp.tapir.DecodeResult.Error(
| s,
| new NoSuchElementException(
| s"Could not find value $s for enum ${enumextensions.EnumMirror[T].mirroredName}, available values: ${enumextensions.EnumMirror[T].values.mkString(", ")}"
| )
| ),
| sttp.tapir.DecodeResult.Value(_)
| )
| )(_.name)
|
|object Test {
| given stringListTestCodec: sttp.tapir.Codec[List[String], Test, sttp.tapir.CodecFormat.TextPlain] =
| makeQueryCodecForEnum[Test]
|}
|enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec, enumextensions.EnumMirror {
| case enum1, enum2
|}""".stripMargin)
}

it should "generate named maps" in {
Expand All @@ -305,8 +339,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
)

val gen = new ClassDefinitionGenerator()
val res = gen.classDefs(doc, false)
"import enumeratum._;" + res.get shouldCompile ()
gen.classDefs(doc, false).get shouldCompile ()
}

import cats.implicits._
Expand Down Expand Up @@ -344,7 +377,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {

val res: String = parserRes match {
case Left(value) => throw new Exception(value)
case Right(doc) => new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
case Right(doc) => new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None)
}

val compileUnit =
Expand All @@ -355,7 +388,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
| $res
|}
| """.stripMargin
println(compileUnit)

compileUnit shouldCompile ()

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
val generatedCode =
BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None)
generatedCode should include("val getTestAsdId =")
generatedCode shouldCompile ()
}
Expand Down Expand Up @@ -131,7 +132,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
)
BasicGenerator.imports ++
new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) shouldCompile ()
new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None) shouldCompile ()
}

it should "handle status codes" in {
Expand Down Expand Up @@ -174,7 +175,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
val generatedCode =
BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false).endpointDecls(None)
generatedCode should include(
""".out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))"""
) // status code with body
Expand Down
Loading

0 comments on commit 633810b

Please sign in to comment.