diff --git a/Realtime/build.gradle.kts b/Realtime/build.gradle.kts index 4dfd1297..0fb87c02 100644 --- a/Realtime/build.gradle.kts +++ b/Realtime/build.gradle.kts @@ -21,6 +21,14 @@ kotlin { api(libs.ktor.client.websockets) } } + val commonTest by getting { + dependencies { + implementation(libs.ktor.server.host) + implementation(libs.ktor.server.websockets) + implementation(project(":test-common")) + implementation(libs.bundles.testing) + } + } } } diff --git a/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/Realtime.kt b/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/Realtime.kt index c8760ec0..c993f203 100644 --- a/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/Realtime.kt +++ b/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/Realtime.kt @@ -12,6 +12,7 @@ import io.github.jan.supabase.plugins.MainPlugin import io.github.jan.supabase.plugins.SupabasePluginProvider import io.github.jan.supabase.serializer.KotlinXSerializer import io.github.jan.supabase.supabaseJson +import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession import io.ktor.client.plugins.websocket.WebSockets import io.ktor.serialization.kotlinx.KotlinxWebsocketSerializationConverter import kotlinx.coroutines.flow.StateFlow @@ -113,6 +114,7 @@ sealed interface Realtime : MainPlugin, CustomSerializationPlug var disconnectOnSessionLoss: Boolean = true, var connectOnSubscribe: Boolean = true, var disconnectOnNoSubscriptions: Boolean = true, + var websocketSessionProvider: (suspend () -> DefaultClientWebSocketSession)? = null, @Deprecated("This property is deprecated and will be removed in a future version.") var eventsPerSecond: Int = 10, ): MainConfig(), CustomSerializationConfig { diff --git a/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/RealtimeImpl.kt b/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/RealtimeImpl.kt index 0f0b5f31..556aa1a2 100644 --- a/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/RealtimeImpl.kt +++ b/Realtime/src/commonMain/kotlin/io/github/jan/supabase/realtime/RealtimeImpl.kt @@ -36,6 +36,9 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.serialization.json.buildJsonObject +import kotlin.collections.component1 +import kotlin.collections.component2 +import kotlin.collections.set import kotlin.time.Duration.Companion.milliseconds @PublishedApi internal class RealtimeImpl(override val supabaseClient: SupabaseClient, override val config: Realtime.Config) : Realtime { @@ -76,7 +79,7 @@ import kotlin.time.Duration.Companion.milliseconds _status.value = Realtime.Status.CONNECTING val realtimeUrl = websocketUrl try { - ws = supabaseClient.httpClient.webSocketSession(realtimeUrl) + ws = config.websocketSessionProvider?.invoke() ?: supabaseClient.httpClient.webSocketSession(realtimeUrl) _status.value = Realtime.Status.CONNECTED Realtime.logger.i { "Connected to realtime websocket!" } listenForMessages() @@ -234,7 +237,7 @@ import kotlin.time.Duration.Companion.milliseconds } override suspend fun close() { - ws?.cancel() + disconnect() } override suspend fun block() { diff --git a/Realtime/src/commonTest/kotlin/FlowUtils.kt b/Realtime/src/commonTest/kotlin/FlowUtils.kt new file mode 100644 index 00000000..7088900a --- /dev/null +++ b/Realtime/src/commonTest/kotlin/FlowUtils.kt @@ -0,0 +1,5 @@ +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.flow.first + +suspend inline fun Flow.waitForValue(value: T) = filter { it == value }.first() \ No newline at end of file diff --git a/Realtime/src/commonTest/kotlin/RealtimeChannelTest.kt b/Realtime/src/commonTest/kotlin/RealtimeChannelTest.kt new file mode 100644 index 00000000..f7edc6b1 --- /dev/null +++ b/Realtime/src/commonTest/kotlin/RealtimeChannelTest.kt @@ -0,0 +1,193 @@ +import io.github.jan.supabase.gotrue.Auth +import io.github.jan.supabase.gotrue.auth +import io.github.jan.supabase.gotrue.minimalSettings +import io.github.jan.supabase.realtime.Realtime +import io.github.jan.supabase.realtime.RealtimeChannel +import io.github.jan.supabase.realtime.RealtimeChannel.Companion.CHANNEL_EVENT_REPLY +import io.github.jan.supabase.realtime.RealtimeChannel.Companion.CHANNEL_EVENT_SYSTEM +import io.github.jan.supabase.realtime.RealtimeJoinPayload +import io.github.jan.supabase.realtime.RealtimeMessage +import io.github.jan.supabase.realtime.channel +import io.github.jan.supabase.realtime.realtime +import io.ktor.server.websocket.receiveDeserialized +import io.ktor.server.websocket.sendSerialized +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.put +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class RealtimeChannelTest { + + @Test + fun testConnectOnSubscribeDisabled() { + createTestClient( + wsHandler = { + //Does not matter for this test + }, + supabaseHandler = { + val channel = it.channel("") + assertFailsWith() { + channel.subscribe() + } + }, + realtimeConfig = { + connectOnSubscribe = false + } + ) + } + + @Test + fun testConnectOnSubscribeEnabled() { + createTestClient( + wsHandler = { + incoming.receive() + }, + supabaseHandler = { + val channel = it.channel("") + channel.subscribe(false) + assertEquals(Realtime.Status.CONNECTED, it.realtime.status.value) + } + ) + } + + @Test + fun testChannelStatusWithoutPostgres() { + val channelId = "channelId" + createTestClient( + wsHandler = { + incoming.receive() + sendSerialized(RealtimeMessage("realtime:$channelId", CHANNEL_EVENT_SYSTEM, buildJsonObject { put("status", "ok") }, "")) + incoming.receive() + sendSerialized(RealtimeMessage("realtime:$channelId", CHANNEL_EVENT_REPLY, buildJsonObject { put("status", "ok") }, "")) + }, + supabaseHandler = { + val channel = it.channel("channelId") + assertEquals(channel.status.value, RealtimeChannel.Status.UNSUBSCRIBED) + assertEquals(it.realtime.status.value, Realtime.Status.DISCONNECTED) + channel.subscribe(blockUntilSubscribed = true) + assertEquals(channel.status.value, RealtimeChannel.Status.SUBSCRIBED) + channel.unsubscribe() + assertEquals(channel.status.value, RealtimeChannel.Status.UNSUBSCRIBING) + assertEquals(RealtimeChannel.Status.UNSUBSCRIBED, channel.status.waitForValue(RealtimeChannel.Status.UNSUBSCRIBED)) + }, + ) + } + + @Test + fun testSendingPayloadWithoutJWT() { + val expectedChannelId = "channelId" + val expectedIsPrivate = true + val expectedReceiveOwnBroadcasts = true + val expectedAcknowledge = true + val expectedPresenceKey = "presenceKey" + createTestClient( + wsHandler = { + val message = this.receiveDeserialized() + val payload = Json.decodeFromJsonElement(message.payload) + assertEquals("realtime:$expectedChannelId", message.topic) + assertEquals(expectedIsPrivate, payload.config.isPrivate) + assertEquals(expectedReceiveOwnBroadcasts, payload.config.broadcast.receiveOwnBroadcasts) + assertEquals(expectedAcknowledge, payload.config.broadcast.acknowledgeBroadcasts) + assertEquals(expectedPresenceKey, payload.config.presence.key) + }, + supabaseHandler = { + val channel = it.channel("channelId") { + isPrivate = expectedIsPrivate + broadcast { + receiveOwnBroadcasts = expectedReceiveOwnBroadcasts + acknowledgeBroadcasts = expectedAcknowledge + } + presence { + key = expectedPresenceKey + } + } + channel.subscribe() + } + ) + } + + @Test + fun testSendingPayloadWithAuthJWT() { + val expectedAuthToken = "authToken" + createTestClient( + wsHandler = { + val message = this.receiveDeserialized() + assertEquals(expectedAuthToken, message.payload["access_token"]?.jsonPrimitive?.content) + }, + supabaseHandler = { + it.auth.importAuthToken(expectedAuthToken) + val channel = it.channel("channelId") + channel.subscribe() + }, + supabaseConfig = { + install(Auth) { + minimalSettings() + } + } + ) + } + + @Test + fun testSendingPayloadWithCustomJWT() { + val expectedAuthToken = "authToken" + createTestClient( + wsHandler = { + val message = this.receiveDeserialized() + assertEquals(expectedAuthToken, message.payload["access_token"]?.jsonPrimitive?.content) + }, + supabaseHandler = { + val channel = it.channel("channelId") + channel.subscribe() + }, + realtimeConfig = { + jwtToken = expectedAuthToken + } + ) + } + + @Test + fun testSendingBroadcasts() { + val message = buildJsonObject { + put("key", "value") + } + val event = "event" + createTestClient( + wsHandler = { + handleSubscribe("channelId") + val rMessage = this.receiveDeserialized() + assertEquals("realtime:channelId", rMessage.topic) + assertEquals("broadcast", rMessage.event) + assertEquals(message, rMessage.payload["payload"]?.jsonObject) + assertEquals(event, rMessage.payload["event"]?.jsonPrimitive?.content) + assertEquals("broadcast", rMessage.payload["type"]?.jsonPrimitive?.content) + }, + supabaseHandler = { + val channel = it.channel("channelId") + channel.subscribe(true) + channel.broadcast(event, message) + } + ) + } + + @Test + fun testSendingPresenceUnsubscribed() { + createTestClient( + wsHandler = { + handleSubscribe("channelId") + }, + supabaseHandler = { + val channel = it.channel("channelId") + channel.subscribe(true) + assertFailsWith { + channel.track(buildJsonObject { }) + } + } + ) + } + +} \ No newline at end of file diff --git a/Realtime/src/commonTest/kotlin/RealtimeTest.kt b/Realtime/src/commonTest/kotlin/RealtimeTest.kt new file mode 100644 index 00000000..689c465d --- /dev/null +++ b/Realtime/src/commonTest/kotlin/RealtimeTest.kt @@ -0,0 +1,50 @@ +import io.github.jan.supabase.realtime.Realtime +import io.github.jan.supabase.realtime.RealtimeMessage +import io.github.jan.supabase.realtime.realtime +import io.ktor.server.websocket.receiveDeserialized +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlin.test.Test +import kotlin.test.assertEquals + +class RealtimeTest { + + @Test + fun testRealtimeStatus() { + createTestClient( + wsHandler = { + //Does not matter for this test + }, + supabaseHandler = { + assertEquals(Realtime.Status.DISCONNECTED, it.realtime.status.value) + it.realtime.connect() + assertEquals(Realtime.Status.CONNECTED, it.realtime.status.value) + it.realtime.disconnect() + assertEquals(Realtime.Status.DISCONNECTED, it.realtime.status.value) + } + ) + } + + @Test + fun testSendingRealtimeMessages() { + val expectedMessage = RealtimeMessage( + topic = "realtimeTopic", + event = "realtimeEvent", + payload = buildJsonObject { + put("key", "value") + }, + ref = "realtimeRef" + ) + createTestClient( + wsHandler = { + val message = this.receiveDeserialized() + assertEquals(expectedMessage, message) + }, + supabaseHandler = { + it.realtime.connect() + it.realtime.send(expectedMessage) + } + ) + } + +} \ No newline at end of file diff --git a/Realtime/src/commonTest/kotlin/RealtimeTestUtils.kt b/Realtime/src/commonTest/kotlin/RealtimeTestUtils.kt new file mode 100644 index 00000000..00bd10a7 --- /dev/null +++ b/Realtime/src/commonTest/kotlin/RealtimeTestUtils.kt @@ -0,0 +1,11 @@ +import io.github.jan.supabase.realtime.RealtimeChannel.Companion.CHANNEL_EVENT_SYSTEM +import io.github.jan.supabase.realtime.RealtimeMessage +import io.ktor.server.websocket.DefaultWebSocketServerSession +import io.ktor.server.websocket.sendSerialized +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put + +suspend fun DefaultWebSocketServerSession.handleSubscribe(channelId: String) { + incoming.receive() + sendSerialized(RealtimeMessage("realtime:$channelId", CHANNEL_EVENT_SYSTEM, buildJsonObject { put("status", "ok") }, "")) +} \ No newline at end of file diff --git a/Realtime/src/commonTest/kotlin/RealtimeWSMock.kt b/Realtime/src/commonTest/kotlin/RealtimeWSMock.kt new file mode 100644 index 00000000..25ca2dc2 --- /dev/null +++ b/Realtime/src/commonTest/kotlin/RealtimeWSMock.kt @@ -0,0 +1,54 @@ +import io.github.jan.supabase.SupabaseClient +import io.github.jan.supabase.SupabaseClientBuilder +import io.github.jan.supabase.createSupabaseClient +import io.github.jan.supabase.logging.LogLevel +import io.github.jan.supabase.realtime.Realtime +import io.github.jan.supabase.supabaseJson +import io.ktor.client.plugins.websocket.webSocket +import io.ktor.serialization.kotlinx.KotlinxWebsocketSerializationConverter +import io.ktor.server.testing.ApplicationTestBuilder +import io.ktor.server.testing.testApplication +import io.ktor.server.websocket.DefaultWebSocketServerSession +import io.ktor.server.websocket.WebSockets +import io.ktor.server.websocket.webSocket + +fun ApplicationTestBuilder.configureServer( + handler: suspend DefaultWebSocketServerSession.() -> Unit +) { + install(WebSockets) { + contentConverter = KotlinxWebsocketSerializationConverter(supabaseJson) + } + routing { + webSocket("/", handler = handler) + } +} + +fun createTestClient( + wsHandler: suspend DefaultWebSocketServerSession.() -> Unit, + supabaseHandler: suspend (SupabaseClient) -> Unit, + realtimeConfig: Realtime.Config.() -> Unit = {}, + supabaseConfig: SupabaseClientBuilder.() -> Unit = {} +) { + testApplication { + configureServer(wsHandler) + val client = createClient { + install(io.ktor.client.plugins.websocket.WebSockets) { + contentConverter = KotlinxWebsocketSerializationConverter(supabaseJson) + } + } + client.webSocket("/") { + val supabase = createSupabaseClient("", "") { + defaultLogLevel = LogLevel.DEBUG + install(Realtime) { + websocketSessionProvider = { + this@webSocket + } + realtimeConfig() + } + supabaseConfig() + } + supabaseHandler(supabase) + supabase.close() + } + } +} \ No newline at end of file diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index ff9c9d50..617ba4ef 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -74,6 +74,8 @@ ktor-client-websockets = { module = "io.ktor:ktor-client-websockets", version.re ktor-client-mock = { module = "io.ktor:ktor-client-mock", version.ref = "ktor" } ktor-json = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" } ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor" } +ktor-server-host = { module = "io.ktor:ktor-server-test-host", version.ref = "ktor" } +ktor-server-websockets = { module = "io.ktor:ktor-server-websockets", version.ref = "ktor" } ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor" } kermit = { module = "co.touchlab:kermit", version.ref = "kermit" }