Skip to content

Commit

Permalink
Support http forward (temp impl)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryoii committed Oct 9, 2023
1 parent ea1cc79 commit ad6384c
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright 2023 Mamoe Technologies and contributors.
*
* 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
* Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
*
* https://github.com/mamoe/mirai/blob/master/LICENSE
*/

package net.mamoe.mirai.api.http.adapter.http.plugin

import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import io.ktor.util.reflect.*


internal val HttpForwardAttributeKey = AttributeKey<HttpForwardContext>("HttpForward")
val HttpForwardPhase = PipelinePhase("Forward")
val HttpForward = createApplicationPlugin("HttpForward", ::HttpForwardConfig) {
application.insertPhaseAfter(ApplicationCallPipeline.Call, HttpForwardPhase)

application.intercept(HttpForwardPhase) {
val forwardContext = call.attributes.getOrNull(HttpForwardAttributeKey)
if (forwardContext != null && !forwardContext.forwarded) {
forwardContext.forwarded = true
forwardContext.convertors = this@createApplicationPlugin.pluginConfig.getConvertors()
finish()
application.execute(
ApplicationForwardCall(call, forwardContext)
)
}
}
}

typealias BodyConvertor = (Any, TypeInfo) -> Any?

class HttpForwardConfig {
private val convertors: MutableList<BodyConvertor> = mutableListOf(DefaultBodyConvertor)
fun addConvertor(convertor: BodyConvertor) {
convertors.add(convertor)
}

internal fun getConvertors(): List<BodyConvertor> = convertors
}

val DefaultBodyConvertor: (Any, TypeInfo) -> Any? = { body, typeInfo ->
if (typeInfo.type.isInstance(body)) body else null
}

internal data class HttpForwardContext(val router: String, val body: Any?) {
var forwarded = false
var convertors = emptyList<BodyConvertor>()
}

fun ApplicationCall.forward(forward: String) {
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, null))
}

fun ApplicationCall.forward(forward: String, body: Any) {
attributes.put(HttpForwardAttributeKey, HttpForwardContext(forward, body))
}

internal fun forwardReceivePipeline(convertors: List<BodyConvertor>, body: Any): ApplicationReceivePipeline =
ApplicationReceivePipeline().apply {
intercept(ApplicationReceivePipeline.Transform) {
proceedWith(convertors.firstNotNullOfOrNull { it.invoke(body, context.receiveType) }
?: throw NoSuchElementException("fuck"))
}
}

internal class ApplicationForwardCall(
val delegate: ApplicationCall, val context: HttpForwardContext
) : ApplicationCall by delegate {
override val request: ApplicationRequest = DelegateApplicationRequest(this, context.router, context.body)
}

internal class DelegateApplicationRequest(
override val call: ApplicationForwardCall, forward: String, body: Any?
) : ApplicationRequest by call.delegate.request {
private val _pipeline by lazy {
body?.let { forwardReceivePipeline(call.context.convertors, it) } ?: call.delegate.request.pipeline
}
override val local = DelegateRequestConnectionPoint(call.delegate.request.local, forward)
override val pipeline: ApplicationReceivePipeline = _pipeline
}

internal class DelegateRequestConnectionPoint(
private val delegate: RequestConnectionPoint, override val uri: String
) : RequestConnectionPoint by delegate

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.testing.*
import net.mamoe.mirai.api.http.adapter.http.support.forward
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.serializer
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.LongTargetDTO
import net.mamoe.mirai.api.http.adapter.internal.dto.parameter.NudgeDTO
import net.mamoe.mirai.api.http.adapter.internal.serializer.BuiltinJsonSerializer
Expand Down Expand Up @@ -127,4 +130,50 @@ class HttpForwardTest {
assertEquals("321", it.bodyAsText())
}
}


@Serializable
private data class NeatedDto(
val router: String,
val body: JsonElement,
)

@OptIn(InternalSerializationApi::class)
@Test
fun testPostRequestForwardNestedBody() = testApplication {
// No need for DoubleReceive
install(GlobalExceptionHandler) { printTrace = true }
install(DoubleReceive)
install(HttpRouterMonitor)
val json = BuiltinJsonSerializer.buildJson()
install(ContentNegotiation) { json(json) }
install(HttpForward) {
addConvertor { body, typeInfo ->
if (body is JsonElement) {
json.decodeFromJsonElement(typeInfo.type.serializer(), body)
} else null
}
}

routing {
post("/test") {
val receive = call.receive<NeatedDto>()
assertEquals("/forward", receive.router)
call.forward("/forward", receive.body)
}

post("/forward") {
val receive = call.receive<LongTargetDTO>()
call.respondText(receive.target.toString())
}
}

client.post("/test") {
contentType(ContentType.Application.Json)
setBody("""{"router":"/forward","body":{"target":321}}""")
}.also {
assertEquals(HttpStatusCode.OK, it.status)
assertEquals("321", it.bodyAsText())
}
}
}

0 comments on commit ad6384c

Please sign in to comment.