diff --git a/build.sbt b/build.sbt index 96f984637..b65e81277 100644 --- a/build.sbt +++ b/build.sbt @@ -767,6 +767,7 @@ lazy val complianceTests = projectMatrix else Seq.empty ce3 ++ Seq( Dependencies.Circe.parser.value, + Dependencies.Smithy.utils , Dependencies.Http4s.circe.value, Dependencies.Http4s.client.value, Dependencies.Weaver.cats.value % Test, diff --git a/modules/compliance-tests/src/smithy4s/compliancetests/ComplianceTest.scala b/modules/compliance-tests/src/smithy4s/compliancetests/ComplianceTest.scala index 6efa93a80..928b2edab 100644 --- a/modules/compliance-tests/src/smithy4s/compliancetests/ComplianceTest.scala +++ b/modules/compliance-tests/src/smithy4s/compliancetests/ComplianceTest.scala @@ -17,9 +17,32 @@ package smithy4s.compliancetests import ComplianceTest.ComplianceResult +import smithy.test.NonEmptyString +import smithy4s.ShapeId -case class ComplianceTest[F[_]](name: String, run: F[ComplianceResult]) +case class ComplianceTest[F[_]]( + id: String, + endpoint: ShapeId, + tags: List[String], + run: F[ComplianceResult] +) { + private val showTags = + if (tags.isEmpty) "" else tags.mkString(" Tags[", ", ", "]") + def show = s"${endpoint.id}: $id $showTags" +} object ComplianceTest { type ComplianceResult = Either[String, Unit] + def apply[F[_]]( + id: String, + endpoint: ShapeId, + tags: Option[List[NonEmptyString]], + run: F[ComplianceResult] + ): ComplianceTest[F] = + ComplianceTest( + id, + endpoint, + tags.getOrElse(List.empty).map(_.value), + run + ) } diff --git a/modules/compliance-tests/src/smithy4s/compliancetests/HttpProtocolCompliance.scala b/modules/compliance-tests/src/smithy4s/compliancetests/HttpProtocolCompliance.scala index 6627742c0..939252f2d 100644 --- a/modules/compliance-tests/src/smithy4s/compliancetests/HttpProtocolCompliance.scala +++ b/modules/compliance-tests/src/smithy4s/compliancetests/HttpProtocolCompliance.scala @@ -46,10 +46,28 @@ object HttpProtocolCompliance { service ).allServerTests() + def malformedRequestTests[F[_], Alg[_[_, _, _, _, _]]]( + impl: Router[F], + service: Service[Alg] + )(implicit ce: CompatEffect[F]): List[ComplianceTest[F]] = + new internals.MalformedRequestComplianceTestCase[F, Alg]( + impl, + service + ).malformedRequestTests() + def clientAndServerTests[F[_], Alg[_[_, _, _, _, _]]]( router: Router[F] with ReverseRouter[F], service: Service[Alg] )(implicit ce: CompatEffect[F]): List[ComplianceTest[F]] = clientTests(router, service) ++ serverTests(router, service) + def allTests[F[_], Alg[_[_, _, _, _, _]]]( + router: Router[F] with ReverseRouter[F], + service: Service[Alg] + )(implicit ce: CompatEffect[F]): List[ComplianceTest[F]] = + clientAndServerTests(router, service) ++ malformedRequestTests( + router, + service + ) + } diff --git a/modules/compliance-tests/src/smithy4s/compliancetests/internals/Assertions.scala b/modules/compliance-tests/src/smithy4s/compliancetests/internals/Assertions.scala index e052809c4..94432aa31 100644 --- a/modules/compliance-tests/src/smithy4s/compliancetests/internals/Assertions.scala +++ b/modules/compliance-tests/src/smithy4s/compliancetests/internals/Assertions.scala @@ -51,23 +51,17 @@ private[internals] object assert { } } - def neql[A: Eq](expected: A, actual: A): ComplianceResult = { - if (expected =!= actual) { - success - } else { - fail( - s"This test passed when it was supposed to fail, Actual value: ${pprint - .apply(actual)} was equal to ${pprint.apply(expected)}." - ) - } - } - - def eql[A: Eq](expected: A, actual: A): ComplianceResult = { - if (expected === actual) { + def eql[A: Eq]( + result: A, + testCase: A, + prefix: String = "" + ): ComplianceResult = { + if (result === testCase) { success } else { fail( - s"Actual value: ${pprint.apply(actual)} was not equal to ${pprint.apply(expected)}." + s"$prefix the result value: ${pprint.apply(result)} was not equal to the expected TestCase value ${pprint + .apply(testCase)}." ) } } @@ -84,6 +78,19 @@ private[internals] object assert { } } + def regexEql( + expected: String, + actual: String + ): ComplianceResult = { + if (actual.matches(expected)) { + success + } else { + fail( + s"Actual value: ${pprint.apply(actual)} was not equal to ${pprint.apply(expected)}." + ) + } + } + private def headersExistenceCheck( headers: Headers, expected: Either[Option[List[String]], Option[List[String]]] @@ -101,7 +108,7 @@ private[internals] object assert { }.combineAll } } - private def headersCheck( + def headersCheck( headers: Headers, expected: Option[Map[String, String]] ) = { diff --git a/modules/compliance-tests/src/smithy4s/compliancetests/internals/ClientHttpComplianceTestCase.scala b/modules/compliance-tests/src/smithy4s/compliancetests/internals/ClientHttpComplianceTestCase.scala index 960c2c0b7..b1220160c 100644 --- a/modules/compliance-tests/src/smithy4s/compliancetests/internals/ClientHttpComplianceTestCase.scala +++ b/modules/compliance-tests/src/smithy4s/compliancetests/internals/ClientHttpComplianceTestCase.scala @@ -31,7 +31,6 @@ import smithy4s.Document import smithy4s.http.PayloadError import smithy4s.Service import cats.Eq - import scala.concurrent.duration._ import smithy4s.http.HttpMediaType import org.http4s.MediaType @@ -56,14 +55,13 @@ private[compliancetests] class ClientHttpComplianceTestCase[ testCase: HttpRequestTestCase ): F[ComplianceResult] = { - val bodyAssert = testCase.body - .map { expectedBody => - request.bodyText.compile.string.map { responseBody => - assert.bodyEql(responseBody, expectedBody, testCase.bodyMediaType) - - } - } - .getOrElse(assert.success.pure[F]) + val bodyAssert = request.bodyText.compile.string.map { responseBody => + assert.bodyEql( + responseBody, + testCase.body.getOrElse(""), + testCase.bodyMediaType + ) + } val expectedUri = baseUri .withPath( @@ -74,14 +72,20 @@ private[compliancetests] class ClientHttpComplianceTestCase[ ) val pathAssert = - assert.eql(expectedUri.path.renderString, request.uri.path.renderString) + assert.eql( + expectedUri.path.renderString, + request.uri.path.renderString, + "path test :" + ) val queryAssert = assert.eql( expectedUri.query.renderString, - request.uri.query.renderString + request.uri.query.renderString, + "query test :" ) val methodAssert = assert.eql( testCase.method.toLowerCase(), - request.method.name.toLowerCase() + request.method.name.toLowerCase(), + "method test :" ) val ioAsserts: List[F[ComplianceResult]] = bodyAssert +: List( @@ -103,7 +107,9 @@ private[compliancetests] class ClientHttpComplianceTestCase[ val revisedSchema = mapAllTimestampsToEpoch(endpoint.input.awsHintMask) val inputFromDocument = Document.Decoder.fromSchema(revisedSchema) ComplianceTest[F]( - name = endpoint.id.toString + "(client|request): " + testCase.id, + testCase.id, + endpoint.id, + testCase.tags, run = { val input = inputFromDocument .decode(testCase.params.getOrElse(Document.obj())) @@ -152,7 +158,9 @@ private[compliancetests] class ClientHttpComplianceTestCase[ val dummyInput = DefaultSchemaVisitor(endpoint.input) ComplianceTest[F]( - name = endpoint.id.toString + "(client|response): " + testCase.id, + testCase.id, + endpoint.id, + testCase.tags, run = { val revisedSchema = mapAllTimestampsToEpoch(endpoint.output.awsHintMask) implicit val outputEq: Eq[O] = @@ -211,8 +219,8 @@ private[compliancetests] class ClientHttpComplianceTestCase[ .apply(endpoint.wrap(dummyInput)) res.map { output => assert.eql( - expectedOutput, - output + output, + expectedOutput ) } } diff --git a/modules/compliance-tests/src/smithy4s/compliancetests/internals/MalformedRequestComplianceTestCase.scala b/modules/compliance-tests/src/smithy4s/compliancetests/internals/MalformedRequestComplianceTestCase.scala new file mode 100644 index 000000000..a5752f702 --- /dev/null +++ b/modules/compliance-tests/src/smithy4s/compliancetests/internals/MalformedRequestComplianceTestCase.scala @@ -0,0 +1,242 @@ +/* + * Copyright 2021-2022 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithy4s.compliancetests +package internals + +import cats.implicits._ +import org.http4s._ +import smithy.test._ +import smithy4s.Service +import smithy4s.kinds._ +import software.amazon.smithy.utils.SimpleCodeWriter + +private[compliancetests] class MalformedRequestComplianceTestCase[ + F[_], + Alg[_[_, _, _, _, _]] +]( + router: Router[F], + serviceInstance: Service[Alg] +)(implicit + ce: CompatEffect[F] +) { + + import ce._ + import org.http4s.implicits._ + import router._ + + private[compliancetests] val originalService: Service[Alg] = serviceInstance + private val baseUri = uri"http://localhost/" + + private def makeRequest( + baseUri: Uri, + testCase: HttpMalformedRequestTestCase + ): Request[F] = { + val req = testCase.request + val expectedHeaders = parseHeaders(req.headers) + val expectedMethod = Method + .fromString(req.method) + .getOrElse(sys.error("Invalid method")) + + val expectedUri = baseUri + .withPath( + Uri.Path.unsafeFromString(req.uri).addEndsWithSlash + ) + .withMultiValueQueryParams( + parseQueryParams(req.queryParams) + ) + val body = + req.body + .map(b => fs2.Stream.emit(b).through(ce.utf8Encode)) + .getOrElse(fs2.Stream.empty) + + Request[F]( + method = expectedMethod, + uri = expectedUri, + headers = expectedHeaders, + body = body + ) + } + + private[compliancetests] def malformedRequestTest[I, E, O, SE, SO]( + endpoint: originalService.Endpoint[I, E, O, SE, SO], + testCase: HttpMalformedRequestTestCase + ): ComplianceTest[F] = { + ComplianceTest[F]( + testCase.id, + endpoint.id, + testCase.tags, + run = { + val fakeImpl: FunctorAlgebra[Alg, F] = + originalService.fromPolyFunction[Kind1[F]#toKind5]( + new originalService.FunctorInterpreter[F] { + def apply[I_, E_, O_, SE_, SO_]( + op: originalService.Operation[I_, E_, O_, SE_, SO_] + ): F[O_] = { + raiseError(new IntendedShortCircuit) + } + } + ) + + routes(fakeImpl)(originalService) + .use { server => + server.orNotFound + .run(makeRequest(baseUri, testCase)) + .attemptNarrow[IntendedShortCircuit] + .flatMap { + case Left(_) => + assert + .fail( + s"Expected a Error Response, but got a IntendedShortCircuit error" + ) + .pure[F] + case Right(resp) => + resp.body + .through(utf8Decode) + .compile + .foldMonoid + .tupleRight(resp.status) + .tupleRight(resp.headers) + .map { case ((actualBody, status), headers) => + val response = testCase.response + val bodyAssert = response.body + .map(malformedResponseBodyDefinition => { + malformedResponseBodyDefinition.assertion match { + case HttpMalformedResponseBodyAssertion + .ContentsCase(contents) => + assert.bodyEql( + contents, + actualBody, + Some( + malformedResponseBodyDefinition.mediaType + ) + ) + case HttpMalformedResponseBodyAssertion + .MessageRegexCase(messageRegex) => + assert.regexEql(messageRegex, actualBody) + } + }) + val assertions = + bodyAssert.toList :+ + assert.headersCheck(headers, response.headers) :+ + assert.eql(status.code, response.code) + assertions.combineAll + } + } + } + } + ) + + } + + /** + * From the docs: + * The lists of values for each key must be identical in length. One test permutation is generated for each index the parameter lists. + * For example, parameters with 5 values for each key will generate 5 tests in total. + */ + + private def interpolateRequest( + request: HttpMalformedRequestDefinition, + writer: SimpleCodeWriter + ): HttpMalformedRequestDefinition = { + HttpMalformedRequestDefinition( + method = writer.format(request.method), + uri = writer.format(request.uri), + host = request.host.map(writer.format(_)), + queryParams = request.queryParams.map(_.map(writer.format(_))), + headers = interpolateHeaders(request.headers, writer), + body = request.body.map(writer.format(_)) + ) + } + + /** + * Interpolate the headers, but don't interpolate the error type as this is specific to Amazon + * @param headers + * @param writer + * @return + */ + private def interpolateHeaders( + headers: Option[Map[String, String]], + writer: SimpleCodeWriter + ): Option[Map[String, String]] = { + headers.map(_.filterNot(_._1.equalsIgnoreCase("x-amzn-errortype")).map { + case (key, value) => + (writer.format(key), writer.format(value)) + }) + } + + private def interpolateResponse( + response: HttpMalformedResponseDefinition, + writer: SimpleCodeWriter + ): HttpMalformedResponseDefinition = { + + HttpMalformedResponseDefinition( + code = response.code, + headers = interpolateHeaders(response.headers, writer), + body = response.body.map { body => + body.copy( + mediaType = writer.format(body.mediaType), + assertion = body.assertion match { + case HttpMalformedResponseBodyAssertion.ContentsCase(contents) => + HttpMalformedResponseBodyAssertion.ContentsCase( + writer.format(contents) + ) + case HttpMalformedResponseBodyAssertion.MessageRegexCase( + messageRegex + ) => + HttpMalformedResponseBodyAssertion.MessageRegexCase( + writer.format(messageRegex) + ) + } + ) + } + ) + + } + private def generateMalformedRequestTests( + malformedRequestTestCase: HttpMalformedRequestTestCase + ): List[HttpMalformedRequestTestCase] = { + malformedRequestTestCase.testParameters + .flatMap(_.get("value")) + .fold(List(malformedRequestTestCase)) { value => + value.map(arg => { + val writer = new SimpleCodeWriter() + writer.putContext("value", arg) + malformedRequestTestCase.copy( + request = + interpolateRequest(malformedRequestTestCase.request, writer), + response = + interpolateResponse(malformedRequestTestCase.response, writer), + protocol = protocolTag.id.toString(), + tags = Some(List(NonEmptyString(arg))) + ) + }) + } + } + + def malformedRequestTests(): List[ComplianceTest[F]] = { + originalService.endpoints.flatMap { case endpoint => + endpoint.hints + .get(HttpMalformedRequestTests) + .map(_.value) + .getOrElse(Nil) + .flatMap(generateMalformedRequestTests) + .filter(_.protocol == protocolTag.id.toString()) + .map(tc => malformedRequestTest(endpoint, tc)) + + } + } +} diff --git a/modules/compliance-tests/src/smithy4s/compliancetests/internals/ServerHttpComplianceTestCase.scala b/modules/compliance-tests/src/smithy4s/compliancetests/internals/ServerHttpComplianceTestCase.scala index cb20c9222..ec37aa4aa 100644 --- a/modules/compliance-tests/src/smithy4s/compliancetests/internals/ServerHttpComplianceTestCase.scala +++ b/modules/compliance-tests/src/smithy4s/compliancetests/internals/ServerHttpComplianceTestCase.scala @@ -93,7 +93,9 @@ private[compliancetests] class ServerHttpComplianceTestCase[ implicit val inputEq: Eq[I] = EqSchemaVisitor(revisedSchema) val inputFromDocument = Document.Decoder.fromSchema(revisedSchema) ComplianceTest[F]( - name = endpoint.id.toString + "(server|request): " + testCase.id, + testCase.id, + endpoint.id, + testCase.tags, run = { deferred[I].flatMap { inputDeferred => val fakeImpl: FunctorAlgebra[Alg, F] = @@ -151,7 +153,9 @@ private[compliancetests] class ServerHttpComplianceTestCase[ ): ComplianceTest[F] = { ComplianceTest[F]( - name = endpoint.id.toString + "(server|response): " + testCase.id, + testCase.id, + endpoint.id, + testCase.tags, run = { val (ammendedService, syntheticRequest) = prepareService(endpoint) diff --git a/modules/http4s/test/src-ce3/smithy4s/compliancetests/ProtocolComplianceTest.scala b/modules/http4s/test/src-ce3/smithy4s/compliancetests/ProtocolComplianceTest.scala index d3009a935..781f67889 100644 --- a/modules/http4s/test/src-ce3/smithy4s/compliancetests/ProtocolComplianceTest.scala +++ b/modules/http4s/test/src-ce3/smithy4s/compliancetests/ProtocolComplianceTest.scala @@ -46,6 +46,7 @@ import smithy4s.schema.Schema.document object ProtocolComplianceTest extends EffectSuite[IO] with BaseCatsSuite { implicit protected def effectCompat: EffectCompat[IO] = CatsUnsafeRun + def getSuite: EffectSuite[IO] = this def spec(args: List[String]): fs2.Stream[IO, TestOutcome] = { @@ -112,7 +113,10 @@ object ProtocolComplianceTest extends EffectSuite[IO] with BaseCatsSuite { .toList .flatMap(wrapper => { HttpProtocolCompliance - .clientAndServerTests(SimpleRestJsonIntegration, wrapper.service) + .allTests( + SimpleRestJsonIntegration, + wrapper.service + ) }) } @@ -134,7 +138,7 @@ object ProtocolComplianceTest extends EffectSuite[IO] with BaseCatsSuite { } private def runInWeaver(tc: ComplianceTest[IO]): IO[TestOutcome] = Test( - tc.name, + tc.show, tc.run .map[Expectations] { case Left(value) => @@ -149,4 +153,9 @@ object ProtocolComplianceTest extends EffectSuite[IO] with BaseCatsSuite { } ) + def expectFailure( + res: ComplianceTest.ComplianceResult + ): Expectations = { + res.foldMap(_ => Expectations.Helpers.success) + } } diff --git a/modules/transformers/src/SimpleRestJsonProtocolTransformer.scala b/modules/transformers/src/SimpleRestJsonProtocolTransformer.scala new file mode 100644 index 000000000..ec545c68d --- /dev/null +++ b/modules/transformers/src/SimpleRestJsonProtocolTransformer.scala @@ -0,0 +1,37 @@ +package smithy4s.transformers + +import alloy.SimpleRestJsonTrait +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.build.{ProjectionTransformer, TransformContext} +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.{Shape, ShapeId} +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.protocoltests.traits._ +import java.util.function.BiFunction +import scala.jdk.CollectionConverters.{CollectionHasAsScala, SeqHasAsJava} +final class SimpleRestJsonProtocolTransformer extends ProjectionTransformer { + override def getName: String = "ProtocolTransformer" + + def transform(ctx: TransformContext): Model = { + val traitMapper: BiFunction[Shape, Trait, Trait] = (_: Shape, theTrait: Trait) => { + theTrait match { + case _:RestJson1Trait => new SimpleRestJsonTrait() + case c: HttpRequestTestsTrait => new HttpRequestTestsTrait(c.getSourceLocation, c.getTestCases.asScala.toList.map { + case req: HttpRequestTestCase => + if (req.getProtocol == ShapeId.from("aws.protocols#restJson1")) + req.toBuilder.protocol(ShapeId.from("alloy#simpleRestJson")).build() + else req + }.asJava) + case c: HttpResponseTestsTrait => new HttpResponseTestsTrait(c.getSourceLocation, c.getTestCases.asScala.toList.map { + case res: HttpResponseTestCase => + if (res.getProtocol == ShapeId.from("aws.protocols#restJson1")) + res.toBuilder.protocol(ShapeId.from("alloy#simpleRestJson")).build() + else res + }.asJava) + case _ => theTrait + } + } + ctx.getTransformer().mapTraits(ctx.getModel(), traitMapper) + } + +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index e4ac4f497..c3aeceeb2 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -24,6 +24,7 @@ object Dependencies { val model = org % "smithy-model" % smithyVersion val testTraits = org % "smithy-protocol-test-traits" % smithyVersion val build = org % "smithy-build" % smithyVersion + val utils = org % "smithy-utils" % smithyVersion val awsTraits = org % "smithy-aws-traits" % smithyVersion val waiters = org % "smithy-waiters" % smithyVersion }