Skip to content

Commit

Permalink
Add ownership and provider to relay constraints updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Pururun committed Oct 30, 2023
1 parent 3ed7531 commit 748393e
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.flow.stateIn
import net.mullvad.mullvadvpn.lib.ipc.Event
import net.mullvad.mullvadvpn.lib.ipc.Request
Expand All @@ -24,9 +25,13 @@ class RelayListListener(
messageHandler
.events<Event.NewRelayList>()
.map { it.relayList }
// This is added so that we always have a relay list. Otherwise sometimes there would
// not be a relay list since the fetching of a relay list would be done before the
// event stream is available.
.onStart { messageHandler.trySendRequest(Request.FetchRelayList) }
.stateIn(CoroutineScope(dispatcher), SharingStarted.Eagerly, null)

fun updateSelectedRelayLocation(value: GeographicLocationConstraint?) {
fun updateSelectedRelayLocation(value: GeographicLocationConstraint) {
messageHandler.trySendRequest(Request.SetRelayLocation(value))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ sealed class Request : Message.RequestMessage() {
@Parcelize data class SetEnableSplitTunneling(val enable: Boolean) : Request()

@Parcelize
data class SetRelayLocation(val relayLocation: GeographicLocationConstraint?) : Request()
data class SetRelayLocation(val relayLocation: GeographicLocationConstraint) : Request()

@Parcelize data class SetWireGuardMtu(val mtu: Int?) : Request()

Expand All @@ -102,6 +102,8 @@ sealed class Request : Message.RequestMessage() {

@Parcelize data class SetProviders(val providers: Constraint<Providers>) : Request()

@Parcelize data object FetchRelayList : Request()

companion object {
private const val MESSAGE_KEY = "request"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,30 @@
package net.mullvad.mullvadvpn.service.endpoint

import kotlin.properties.Delegates.observable
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.channels.actor
import kotlinx.coroutines.channels.trySendBlocking
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.launch
import net.mullvad.mullvadvpn.lib.ipc.Event
import net.mullvad.mullvadvpn.lib.ipc.Request
import net.mullvad.mullvadvpn.model.Constraint
import net.mullvad.mullvadvpn.model.GeographicLocationConstraint
import net.mullvad.mullvadvpn.model.LocationConstraint
import net.mullvad.mullvadvpn.model.RelayConstraintsUpdate
import net.mullvad.mullvadvpn.model.RelayList
import net.mullvad.mullvadvpn.model.RelaySettingsUpdate
import net.mullvad.mullvadvpn.model.WireguardConstraints
import net.mullvad.mullvadvpn.service.MullvadDaemon

class RelayListListener(endpoint: ServiceEndpoint) {

private val commandChannel = spawnActor()
class RelayListListener(
endpoint: ServiceEndpoint,
dispatcher: CoroutineDispatcher = Dispatchers.IO
) {
private val scope: CoroutineScope = CoroutineScope(SupervisorJob() + dispatcher)
private val daemon = endpoint.intermittentDaemon

private var selectedRelayLocation by
observable<GeographicLocationConstraint?>(null) { _, _, _ ->
commandChannel.trySendBlocking(Command.SetRelayLocation)
}
private var selectedWireguardConstraints by
observable<WireguardConstraints?>(null) { _, _, _ ->
commandChannel.trySendBlocking(Command.SetWireguardConstraints)
}

var relayList by
observable<RelayList?>(null) { _, _, relays ->
endpoint.sendEvent(Event.NewRelayList(relays))
Expand All @@ -46,18 +39,75 @@ class RelayListListener(endpoint: ServiceEndpoint) {
}
}

endpoint.dispatcher.registerHandler(Request.SetRelayLocation::class) { request ->
selectedRelayLocation = request.relayLocation
scope.launch {
endpoint.dispatcher.parsedMessages
.filterIsInstance<Request.SetRelayLocation>()
.collect { request ->
val update =
RelayConstraintsUpdate(
location =
Constraint.Only(LocationConstraint.Location(request.relayLocation)),
providers = null,
ownership = null,
wireguardConstraints = null
)
daemon.await().updateRelaySettings(RelaySettingsUpdate.Normal(update))
}
}

scope.launch {
endpoint.dispatcher.parsedMessages
.filterIsInstance<Request.SetWireguardConstraints>()
.collect { request ->
val update =
RelayConstraintsUpdate(
location = null,
providers = null,
ownership = null,
wireguardConstraints = request.wireguardConstraints
)
daemon.await().updateRelaySettings(RelaySettingsUpdate.Normal(update))
}
}

scope.launch {
endpoint.dispatcher.parsedMessages.filterIsInstance<Request.SetOwnership>().collect {
request ->
val update =
RelayConstraintsUpdate(
location = null,
providers = null,
ownership = request.ownership,
wireguardConstraints = null
)
daemon.await().updateRelaySettings(RelaySettingsUpdate.Normal(update))
}
}

scope.launch {
endpoint.dispatcher.parsedMessages.filterIsInstance<Request.SetProviders>().collect {
request ->
val update =
RelayConstraintsUpdate(
location = null,
providers = request.providers,
ownership = null,
wireguardConstraints = null
)
daemon.await().updateRelaySettings(RelaySettingsUpdate.Normal(update))
}
}

endpoint.dispatcher.registerHandler(Request.SetWireguardConstraints::class) { request ->
selectedWireguardConstraints = request.wireguardConstraints
scope.launch {
endpoint.dispatcher.parsedMessages.filterIsInstance<Request.FetchRelayList>().collect {
relayList = daemon.await().getRelayLocations()
}
}
}

fun onDestroy() {
commandChannel.close()
daemon.unregisterListener(this)
scope.cancel()
}

private fun setUpListener(daemon: MullvadDaemon) {
Expand All @@ -71,45 +121,4 @@ class RelayListListener(endpoint: ServiceEndpoint) {
}
}
}

private fun spawnActor() =
GlobalScope.actor<Command>(Dispatchers.Default, Channel.CONFLATED) {
try {
for (command in channel) {
when (command) {
Command.SetRelayLocation,
Command.SetWireguardConstraints -> updateRelayConstraints()
}
}
} catch (exception: ClosedReceiveChannelException) {
// Closed sender, so stop the actor
}
}

private suspend fun updateRelayConstraints() {
val location: Constraint<LocationConstraint> =
selectedRelayLocation?.let { location ->
Constraint.Only(LocationConstraint.Location(location))
} ?: Constraint.Any()
val wireguardConstraints: WireguardConstraints? = selectedWireguardConstraints

val update =
RelaySettingsUpdate.Normal(
RelayConstraintsUpdate(
location = location,
wireguardConstraints = wireguardConstraints,
ownership = Constraint.Any(),
providers = Constraint.Any()
)
)

daemon.await().updateRelaySettings(update)
}

companion object {
private enum class Command {
SetRelayLocation,
SetWireguardConstraints
}
}
}

0 comments on commit 748393e

Please sign in to comment.