Skip to content

Commit

Permalink
Add support for WebSockets in Akka HTTP Adapter (#219)
Browse files Browse the repository at this point in the history
* Add support for WebSockets in Akka Adapter

* Clean up map on stop
  • Loading branch information
ghostdogpr authored Feb 17, 2020
1 parent 4609ef2 commit 712f74c
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 12 deletions.
113 changes: 107 additions & 6 deletions akka-http/src/main/scala/caliban/AkkaHttpAdapter.scala
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
package caliban

import scala.concurrent.ExecutionContext
import akka.http.scaladsl.model.MediaTypes.`application/json`
import akka.http.scaladsl.model.ws.{ Message, TextMessage }
import akka.http.scaladsl.model.{ HttpEntity, HttpResponse, StatusCodes }
import akka.http.scaladsl.server.Directives.complete
import akka.http.scaladsl.server.{ Route, StandardRoute }
import akka.stream.scaladsl.{ Flow, Sink, Source, SourceQueueWithComplete }
import akka.stream.{ Materializer, OverflowStrategy, QueueOfferResult }
import caliban.ResponseValue.{ ObjectValue, StreamValue }
import caliban.Value.NullValue
import de.heikoseeberger.akkahttpcirce.FailFastCirceSupport
import io.circe.Decoder.Result
import io.circe.Json
import io.circe.parser._
import io.circe.syntax._
import zio.{ Runtime, URIO }

import scala.concurrent.ExecutionContext
import zio.{ Fiber, IO, Ref, Runtime, Task, URIO }

object AkkaHttpAdapter extends FailFastCirceSupport {

private def execute[R, E](interpreter: GraphQLInterpreter[R, E], query: GraphQLRequest): URIO[R, GraphQLResponse[E]] =
interpreter.execute(query.query, query.operationName, query.variables.getOrElse(Map()))

private def executeHttpResponse[R, E](
interpreter: GraphQLInterpreter[R, E],
request: GraphQLRequest
): URIO[R, HttpResponse] =
interpreter
.execute(request.query, request.operationName, request.variables.getOrElse(Map()))
execute(interpreter, request)
.foldCause(cause => GraphQLResponse(NullValue, cause.defects).asJson, _.asJson)
.map(gqlResult => HttpResponse(StatusCodes.OK, entity = HttpEntity(`application/json`, gqlResult.toString())))

def getGraphQLRequest(query: String, op: Option[String], vars: Option[String]): Result[GraphQLRequest] = {
import io.circe.parser._
val variablesJs = vars.flatMap(parse(_).toOption)
val fields = List("query" -> Json.fromString(query)) ++
op.map(o => "operationName" -> Json.fromString(o)) ++
Expand Down Expand Up @@ -60,4 +65,100 @@ object AkkaHttpAdapter extends FailFastCirceSupport {
entity(as[GraphQLRequest]) { completeRequest(interpreter) }
}
}

def makeWebSocketService[R, E](
interpreter: GraphQLInterpreter[R, E]
)(implicit ec: ExecutionContext, runtime: Runtime[R], materializer: Materializer): Route = {
def sendMessage(
sendQueue: SourceQueueWithComplete[Message],
id: String,
data: ResponseValue,
errors: List[E]
): Task[QueueOfferResult] =
IO.fromFuture(
_ =>
sendQueue.offer(
TextMessage(
Json
.obj(
"id" -> Json.fromString(id),
"type" -> Json.fromString("data"),
"payload" -> GraphQLResponse(data, errors).asJson
)
.noSpaces
)
)
)

import akka.http.scaladsl.server.Directives._

get {
extractUpgradeToWebSocket { upgrade =>
val (queue, source) = Source.queue[Message](0, OverflowStrategy.fail).preMaterialize()
val subscriptions = runtime.unsafeRun(Ref.make(Map.empty[String, Fiber[Throwable, Unit]]))
val sink = Sink.foreach[Message] {
case TextMessage.Strict(text) =>
val io = for {
msg <- Task.fromEither(decode[Json](text))
msgType = msg.hcursor.downField("type").success.flatMap(_.value.asString).getOrElse("")
_ <- IO.whenCase(msgType) {
case "connection_init" =>
IO.fromFuture(_ => queue.offer(TextMessage("""{"type":"connection_ack"}""")))
case "connection_terminate" =>
IO.effect(queue.complete())
case "start" =>
val payload = msg.hcursor.downField("payload")
val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("")
Task.whenCase(payload.downField("query").success.flatMap(_.value.asString)) {
case Some(query) =>
val operationName = payload.downField("operationName").success.flatMap(_.value.asString)
(for {
result <- execute(interpreter, GraphQLRequest(query, operationName, None))
_ <- result.data match {
case ObjectValue((fieldName, StreamValue(stream)) :: Nil) =>
stream.foreach { item =>
sendMessage(queue, id, ObjectValue(List(fieldName -> item)), result.errors)
}.fork.flatMap(fiber => subscriptions.update(_.updated(id, fiber)))
case other =>
sendMessage(queue, id, other, result.errors) *> IO.fromFuture(
_ => queue.offer(TextMessage(s"""{"type":"complete","id":"$id"}"""))
)
}
} yield ()).catchAll(
error =>
IO.fromFuture(
_ =>
queue.offer(
TextMessage(
Json
.obj(
"id" -> Json.fromString(id),
"type" -> Json.fromString("complete"),
"payload" -> Json.fromString(error.toString)
)
.noSpaces
)
)
)
)
}
case "stop" =>
val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("")
subscriptions
.modify(map => (map.get(id), map - id))
.flatMap(fiber => IO.whenCase(fiber) { case Some(fiber) => fiber.interrupt })
}
} yield ()
runtime.unsafeRun(io)
case _ => ()
}

val flow = Flow.fromSinkAndSource(sink, source).watchTermination() { (_, f) =>
f.onComplete(_ => runtime.unsafeRun(subscriptions.get.flatMap(m => IO.foreach(m.values)(_.interrupt).unit)))
}

complete(upgrade.handleMessages(flow, subprotocol = Some("graphql-ws")))
}
}
}
}
23 changes: 18 additions & 5 deletions examples/src/main/scala/caliban/akkahttp/ExampleApp.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package caliban.akkahttp

import scala.language.postfixOps
import scala.io.StdIn
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
Expand All @@ -8,9 +9,12 @@ import caliban.ExampleData.{ sampleCharacters, Character, CharacterArgs, Charact
import caliban.GraphQL.graphQL
import caliban.schema.Annotations.{ GQLDeprecated, GQLDescription }
import caliban.schema.GenericSchema
import caliban.{ AkkaHttpAdapter, ExampleService, RootResolver }
import caliban.wrappers.ApolloTracing.apolloTracing
import caliban.wrappers.Wrappers._
import caliban.{ AkkaHttpAdapter, ExampleService, GraphQL, RootResolver }
import zio.clock.Clock
import zio.console.Console
import zio.duration._
import zio.stream.ZStream
import zio.{ DefaultRuntime, URIO }

Expand All @@ -34,9 +38,7 @@ object ExampleApp extends App with GenericSchema[Console with Clock] {
case class Mutations(deleteCharacter: CharacterArgs => URIO[Console, Boolean])
case class Subscriptions(characterDeleted: ZStream[Console, Nothing, String])

val service = defaultRuntime.unsafeRun(ExampleService.make(sampleCharacters))

val interpreter =
def makeApi(service: ExampleService): GraphQL[Console with Clock] =
graphQL(
RootResolver(
Queries(
Expand All @@ -46,7 +48,16 @@ object ExampleApp extends App with GenericSchema[Console with Clock] {
Mutations(args => service.deleteCharacter(args.name)),
Subscriptions(service.deletedEvents)
)
).interpreter
) @@
maxFields(200) @@ // query analyzer that limit query fields
maxDepth(30) @@ // query analyzer that limit query depth
timeout(3 seconds) @@ // wrapper that fails slow queries
printSlowQueries(500 millis) @@ // wrapper that logs slow queries
apolloTracing // wrapper for https://github.com/apollographql/apollo-tracing

val service = defaultRuntime.unsafeRun(ExampleService.make(sampleCharacters))

val interpreter = makeApi(service).interpreter

/**
* curl -X POST \
Expand All @@ -60,6 +71,8 @@ object ExampleApp extends App with GenericSchema[Console with Clock] {
val route =
path("api" / "graphql") {
AkkaHttpAdapter.makeHttpService(interpreter)
} ~ path("ws" / "graphql") {
AkkaHttpAdapter.makeWebSocketService(interpreter)
} ~ path("graphiql") {
getFromResource("graphiql.html")
}
Expand Down
4 changes: 3 additions & 1 deletion http4s/src/main/scala/caliban/Http4sAdapter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ object Http4sAdapter {
}
case "stop" =>
val id = msg.hcursor.downField("id").success.flatMap(_.value.asString).getOrElse("")
subscriptions.get.flatMap(map => IO.whenCase(map.get(id)) { case Some(fiber) => fiber.interrupt })
subscriptions
.modify(map => (map.get(id), map - id))
.flatMap(fiber => IO.whenCase(fiber) { case Some(fiber) => fiber.interrupt })
}
} yield ()
}
Expand Down

0 comments on commit 712f74c

Please sign in to comment.