diff --git a/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt b/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt index bfbee80f9efe..3b5da50d3374 100644 --- a/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt +++ b/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt @@ -12,6 +12,7 @@ import kotlinx.coroutines.flow.asSharedFlow import net.mullvad.mullvadvpn.compose.setContentWithTheme import net.mullvad.mullvadvpn.compose.state.SelectLocationUiState import net.mullvad.mullvadvpn.compose.test.CIRCULAR_PROGRESS_INDICATOR +import net.mullvad.mullvadvpn.model.Constraint import net.mullvad.mullvadvpn.model.PortRange import net.mullvad.mullvadvpn.model.RelayEndpointData import net.mullvad.mullvadvpn.model.RelayList @@ -188,6 +189,6 @@ class SelectLocationScreenTest { arrayListOf(DUMMY_RELAY_COUNTRY_1, DUMMY_RELAY_COUNTRY_2), DUMMY_WIREGUARD_ENDPOINT_DATA ) - .toRelayCountries() + .toRelayCountries(ownership = Constraint.Any(), providers = Constraint.Any()) } } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt index 398e27820e82..bfd3f061d5ea 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/di/UiModule.kt @@ -18,10 +18,13 @@ import net.mullvad.mullvadvpn.repository.InAppNotificationController import net.mullvad.mullvadvpn.repository.PrivacyDisclaimerRepository import net.mullvad.mullvadvpn.repository.SettingsRepository import net.mullvad.mullvadvpn.ui.serviceconnection.MessageHandler +import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager import net.mullvad.mullvadvpn.ui.serviceconnection.SplitTunneling import net.mullvad.mullvadvpn.usecase.AccountExpiryNotificationUseCase import net.mullvad.mullvadvpn.usecase.NewDeviceNotificationUseCase +import net.mullvad.mullvadvpn.usecase.PortRangeUseCase +import net.mullvad.mullvadvpn.usecase.RelayListUseCase import net.mullvad.mullvadvpn.usecase.TunnelStateNotificationUseCase import net.mullvad.mullvadvpn.usecase.VersionNotificationUseCase import net.mullvad.mullvadvpn.util.ChangelogDataProvider @@ -88,25 +91,29 @@ val uiModule = module { single { TunnelStateNotificationUseCase(get()) } single { VersionNotificationUseCase(get(), BuildConfig.ENABLE_IN_APP_VERSION_NOTIFICATIONS) } single { NewDeviceNotificationUseCase(get()) } + single { PortRangeUseCase(get()) } + single { RelayListUseCase(get(), get()) } single { InAppNotificationController(get(), get(), get(), get(), MainScope()) } single { ChangelogDataProvider(get()) } + single { RelayListListener(get()) } + // View models viewModel { AccountViewModel(get(), get(), get()) } viewModel { ChangelogViewModel(get(), BuildConfig.VERSION_CODE, BuildConfig.ALWAYS_SHOW_CHANGELOG) } - viewModel { ConnectViewModel(get(), get(), get(), get(), get()) } + viewModel { ConnectViewModel(get(), get(), get(), get(), get(), get()) } viewModel { DeviceListViewModel(get(), get()) } viewModel { DeviceRevokedViewModel(get(), get()) } viewModel { LoginViewModel(get(), get(), get()) } viewModel { PrivacyDisclaimerViewModel(get()) } - viewModel { SelectLocationViewModel(get()) } + viewModel { SelectLocationViewModel(get(), get()) } viewModel { SettingsViewModel(get(), get()) } viewModel { VoucherDialogViewModel(get(), get()) } - viewModel { VpnSettingsViewModel(get(), get(), get(), get()) } + viewModel { VpnSettingsViewModel(get(), get(), get(), get(), get()) } viewModel { WelcomeViewModel(get(), get(), get()) } viewModel { ReportProblemViewModel(get()) } viewModel { ViewLogsViewModel(get()) } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/Provider.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/Provider.kt new file mode 100644 index 000000000000..c10397670050 --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/Provider.kt @@ -0,0 +1,3 @@ +package net.mullvad.mullvadvpn.relaylist + +data class Provider(val name: String, val mullvadOwned: Boolean) diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayList.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayList.kt new file mode 100644 index 000000000000..c74c74ba437a --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayList.kt @@ -0,0 +1,3 @@ +package net.mullvad.mullvadvpn.relaylist + +data class RelayList(val country: List, val selectedItem: RelayItem?) diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayListExtensions.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayListExtensions.kt index 39618f96034e..71d4701e41dc 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayListExtensions.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/relaylist/RelayListExtensions.kt @@ -2,14 +2,20 @@ package net.mullvad.mullvadvpn.relaylist import net.mullvad.mullvadvpn.model.Constraint import net.mullvad.mullvadvpn.model.GeographicLocationConstraint +import net.mullvad.mullvadvpn.model.Ownership +import net.mullvad.mullvadvpn.model.Providers +import net.mullvad.mullvadvpn.model.Relay as DaemonRelay import net.mullvad.mullvadvpn.model.RelayList /** * Convert from a model.RelayList to list of relaylist.RelayCountry Non-wiregaurd relays are - * filtered out So are also cities that only contains non-wireguard relays Countries, cities and - * relays are ordered by name + * filtered out and also relays that do not fit the ownership and provider list So are also cities + * that only contains non-wireguard relays Countries, cities and relays are ordered by name */ -fun RelayList.toRelayCountries(): List { +fun RelayList.toRelayCountries( + ownership: Constraint, + providers: Constraint +): List { val relayCountries = this.countries .map { country -> @@ -27,7 +33,8 @@ fun RelayList.toRelayCountries(): List { relays = relays ) - val validCityRelays = city.relays.filter { relay -> relay.isWireguardRelay } + val validCityRelays = + city.relays.filterValidRelays(ownership = ownership, providers = providers) for (relay in validCityRelays) { relays.add( @@ -170,6 +177,28 @@ fun List.filterOnSearchTerm( } } +private fun List.filterValidRelays( + ownership: Constraint, + providers: Constraint +): List = + filter { it.isWireguardRelay } + .filter { + when (ownership) { + is Constraint.Any -> true + is Constraint.Only -> + when (ownership.value) { + Ownership.MullvadOwned -> it.owned + Ownership.Rented -> !it.owned + } + } + } + .filter { relay -> + when (providers) { + is Constraint.Any -> true + is Constraint.Only -> providers.value.providers.contains(relay.provider) + } + } + /** Expand the parent(s), if any, for the current selected item */ private fun List.expandItemForSelection( selectedItem: RelayItem? diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt index 5c6e765b4e79..0a1767624c64 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt @@ -1,135 +1,52 @@ package net.mullvad.mullvadvpn.ui.serviceconnection -import android.os.Messenger -import net.mullvad.mullvadvpn.lib.common.util.toGeographicLocationConstraint +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onStart +import kotlinx.coroutines.flow.stateIn import net.mullvad.mullvadvpn.lib.ipc.Event -import net.mullvad.mullvadvpn.lib.ipc.EventDispatcher import net.mullvad.mullvadvpn.lib.ipc.Request import net.mullvad.mullvadvpn.model.Constraint import net.mullvad.mullvadvpn.model.GeographicLocationConstraint -import net.mullvad.mullvadvpn.model.PortRange -import net.mullvad.mullvadvpn.model.RelayConstraints -import net.mullvad.mullvadvpn.model.RelaySettings +import net.mullvad.mullvadvpn.model.Ownership +import net.mullvad.mullvadvpn.model.Providers +import net.mullvad.mullvadvpn.model.RelayList import net.mullvad.mullvadvpn.model.WireguardConstraints -import net.mullvad.mullvadvpn.relaylist.RelayCountry -import net.mullvad.mullvadvpn.relaylist.RelayItem -import net.mullvad.mullvadvpn.relaylist.findItemForLocation -import net.mullvad.mullvadvpn.relaylist.toRelayCountries +import net.mullvad.mullvadvpn.model.WireguardEndpointData class RelayListListener( - private val connection: Messenger, - eventDispatcher: EventDispatcher, - private val settingsListener: SettingsListener + private val messageHandler: MessageHandler, + dispatcher: CoroutineDispatcher = Dispatchers.IO ) { - private var relayCountries: List? = null - private var relaySettings: RelaySettings? = null - private var portRanges: List = emptyList() - - var selectedRelayItem: RelayItem? = null - private set + val relayListEvents: StateFlow = + messageHandler + .events() + .map { it.relayList ?: defaultRelayList() } + // This is added so that we always have a relay list. Otherwise sometimes there would + // not be a relay list since the fetching of a relay list would be done before the + // event stream is available. + .onStart { messageHandler.trySendRequest(Request.FetchRelayList) } + .stateIn(CoroutineScope(dispatcher), SharingStarted.Eagerly, defaultRelayList()) fun updateSelectedRelayLocation(value: GeographicLocationConstraint) { - connection.send(Request.SetRelayLocation(value).message) + messageHandler.trySendRequest(Request.SetRelayLocation(value)) } fun updateSelectedWireguardConstraints(value: WireguardConstraints) { - connection.send(Request.SetWireguardConstraints(value).message) - } - - var onRelayCountriesChange: ((List, RelayItem?) -> Unit)? = null - set(value) { - field = value - - synchronized(this) { - val relayCountries = this.relayCountries - - if (relayCountries != null) { - value?.invoke(relayCountries, selectedRelayItem) - } - } - } - - var onPortRangesChange: ((List) -> Unit)? = null - set(value) { - field = value - - synchronized(this) { value?.invoke(portRanges) } - } - - init { - eventDispatcher.registerHandler(Event.NewRelayList::class) { event -> - event.relayList?.let { relayLocations -> - relayListChanged(relayLocations.toRelayCountries()) - portRangesChanged(relayLocations.wireguardEndpointData.portRanges) - } - } - - settingsListener.relaySettingsNotifier.subscribe(this) { newRelaySettings -> - relaySettingsChanged(newRelaySettings) - } - } - - fun onDestroy() { - settingsListener.relaySettingsNotifier.unsubscribe(this) - onRelayCountriesChange = null - } - - private fun relaySettingsChanged(newRelaySettings: RelaySettings?) { - synchronized(this) { - val relayCountries = this.relayCountries - val portRanges = this.portRanges - - relaySettings = - newRelaySettings - ?: RelaySettings.Normal( - RelayConstraints( - location = Constraint.Any(), - ownership = Constraint.Any(), - wireguardConstraints = WireguardConstraints(Constraint.Any()), - providers = Constraint.Any() - ) - ) - - if (relayCountries != null) { - relayListChanged(relayCountries) - } - portRangesChanged(portRanges) - } + messageHandler.trySendRequest(Request.SetWireguardConstraints(value)) } - private fun relayListChanged(newRelayCountries: List) { - synchronized(this) { - relayCountries = newRelayCountries - selectedRelayItem = findSelectedRelayItem() - - onRelayCountriesChange?.invoke(newRelayCountries, selectedRelayItem) - } + fun updateSelectedOwnershipFilter(value: Constraint) { + messageHandler.trySendRequest(Request.SetOwnership(value)) } - private fun portRangesChanged(newPortRanges: List) { - synchronized(this) { - portRanges = newPortRanges - - onPortRangesChange?.invoke(portRanges) - } + fun updateSelectedProvidersFilter(value: Constraint) { + messageHandler.trySendRequest(Request.SetProviders(value)) } - private fun findSelectedRelayItem(): RelayItem? { - val relaySettings = this.relaySettings - - when (relaySettings) { - is RelaySettings.CustomTunnelEndpoint -> return null - is RelaySettings.Normal -> { - val location = relaySettings.relayConstraints.location - return relayCountries?.findItemForLocation( - location.toGeographicLocationConstraint() - ) - } - else -> { - /* NOOP */ - } - } - - return null - } + private fun defaultRelayList() = RelayList(ArrayList(), WireguardEndpointData(ArrayList())) } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionContainer.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionContainer.kt index 7d34f2a96f2c..ca156bed66e8 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionContainer.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionContainer.kt @@ -37,7 +37,6 @@ class ServiceConnectionContainer( val appVersionInfoCache = AppVersionInfoCache(dispatcher, settingsListener) val customDns = CustomDns(connection) - var relayListListener = RelayListListener(connection, dispatcher, settingsListener) private var listenerId: Int? = null @@ -68,7 +67,6 @@ class ServiceConnectionContainer( voucherRedeemer.onDestroy() appVersionInfoCache.onDestroy() - relayListListener.onDestroy() } private fun registerListener(connection: Messenger) { diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManager.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManager.kt index d840b934919c..556d69ecfe33 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManager.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManager.kt @@ -12,6 +12,7 @@ import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.map import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointConfiguration import net.mullvad.mullvadvpn.lib.endpoint.BuildConfig import net.mullvad.mullvadvpn.lib.endpoint.putApiEndpointConfigurationExtra @@ -93,7 +94,7 @@ class ServiceConnectionManager(private val context: Context) : MessageHandler { } override fun events(klass: KClass): Flow { - return events.filterIsInstance(klass) + return events.map { it }.filterIsInstance(klass) } override fun trySendRequest(request: Request): Boolean { diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManagerExtensions.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManagerExtensions.kt index c4b3d100bd30..3232d20cbd9e 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManagerExtensions.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/ServiceConnectionManagerExtensions.kt @@ -14,9 +14,6 @@ fun ServiceConnectionManager.deviceDataSource() = fun ServiceConnectionManager.customDns() = this.connectionState.value.readyContainer()?.customDns -fun ServiceConnectionManager.relayListListener() = - this.connectionState.value.readyContainer()?.relayListListener - fun ServiceConnectionManager.settingsListener() = this.connectionState.value.readyContainer()?.settingsListener diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/PortRangeUseCase.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/PortRangeUseCase.kt new file mode 100644 index 000000000000..2b104cda3985 --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/PortRangeUseCase.kt @@ -0,0 +1,14 @@ +package net.mullvad.mullvadvpn.usecase + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.distinctUntilChanged +import kotlinx.coroutines.flow.map +import net.mullvad.mullvadvpn.model.PortRange +import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener + +class PortRangeUseCase(private val relayListListener: RelayListListener) { + fun portRanges(): Flow> = + relayListListener.relayListEvents + .map { it?.wireguardEndpointData?.portRanges ?: emptyList() } + .distinctUntilChanged() +} diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListFilterUseCase.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListFilterUseCase.kt new file mode 100644 index 000000000000..a26f302f9ca4 --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListFilterUseCase.kt @@ -0,0 +1,45 @@ +package net.mullvad.mullvadvpn.usecase + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import net.mullvad.mullvadvpn.model.Constraint +import net.mullvad.mullvadvpn.model.Ownership +import net.mullvad.mullvadvpn.model.Providers +import net.mullvad.mullvadvpn.model.RelayListCity +import net.mullvad.mullvadvpn.model.RelayListCountry +import net.mullvad.mullvadvpn.relaylist.Provider +import net.mullvad.mullvadvpn.repository.SettingsRepository +import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener + +class RelayListFilterUseCase( + private val relayListListener: RelayListListener, + private val settingsRepository: SettingsRepository +) { + fun updateOwnershipFilter(ownership: Constraint) { + relayListListener.updateSelectedOwnershipFilter(ownership) + } + + fun updateProviderFilter(providers: Constraint) { + relayListListener.updateSelectedProvidersFilter(providers) + } + + fun selectedOwnership(): Flow> = + settingsRepository.settingsUpdates.map { settings -> + settings?.relaySettings?.relayConstraints()?.ownership ?: Constraint.Any() + } + + fun selectedProviders(): Flow> = + settingsRepository.settingsUpdates.map { settings -> + settings?.relaySettings?.relayConstraints()?.providers ?: Constraint.Any() + } + + fun availableProviders(): Flow> = + relayListListener.relayListEvents.map { relayList -> + relayList.countries + .flatMap(RelayListCountry::cities) + .flatMap(RelayListCity::relays) + .filter { relay -> relay.isWireguardRelay } + .map { relay -> Provider(relay.provider, relay.owned) } + .distinct() + } +} diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt new file mode 100644 index 000000000000..0bfe1d038c81 --- /dev/null +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt @@ -0,0 +1,57 @@ +package net.mullvad.mullvadvpn.usecase + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.map +import net.mullvad.mullvadvpn.lib.common.util.toGeographicLocationConstraint +import net.mullvad.mullvadvpn.model.Constraint +import net.mullvad.mullvadvpn.model.GeographicLocationConstraint +import net.mullvad.mullvadvpn.model.RelaySettings +import net.mullvad.mullvadvpn.model.WireguardConstraints +import net.mullvad.mullvadvpn.relaylist.RelayCountry +import net.mullvad.mullvadvpn.relaylist.RelayItem +import net.mullvad.mullvadvpn.relaylist.RelayList +import net.mullvad.mullvadvpn.relaylist.findItemForLocation +import net.mullvad.mullvadvpn.relaylist.toRelayCountries +import net.mullvad.mullvadvpn.repository.SettingsRepository +import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener + +class RelayListUseCase( + private val relayListListener: RelayListListener, + private val settingsRepository: SettingsRepository +) { + + fun updateSelectedRelayLocation(value: GeographicLocationConstraint) { + relayListListener.updateSelectedRelayLocation(value) + } + + fun updateSelectedWireguardConstraints(value: WireguardConstraints) { + relayListListener.updateSelectedWireguardConstraints(value) + } + + fun relayListWithSelection(): Flow = + combine(relayListListener.relayListEvents, settingsRepository.settingsUpdates) { + relayList, + settings -> + val ownership = + settings?.relaySettings?.relayConstraints()?.ownership ?: Constraint.Any() + val providers = + settings?.relaySettings?.relayConstraints()?.providers ?: Constraint.Any() + val relayCountries = + relayList.toRelayCountries(ownership = ownership, providers = providers) + val selectedItem = + relayCountries.findSelectedRelayItem( + relaySettings = settings?.relaySettings, + ) + RelayList(relayCountries, selectedItem) + } + + fun selectedRelayItem(): Flow = relayListWithSelection().map { it.selectedItem } + + private fun List.findSelectedRelayItem( + relaySettings: RelaySettings?, + ): RelayItem? { + val location = relaySettings?.relayConstraints()?.location + return location?.let { this.findItemForLocation(location.toGeographicLocationConstraint()) } + } +} diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt index 8a4f087d64b0..9f6cec8391f1 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModel.kt @@ -27,13 +27,13 @@ import net.mullvad.mullvadvpn.repository.DeviceRepository import net.mullvad.mullvadvpn.repository.InAppNotificationController import net.mullvad.mullvadvpn.ui.serviceconnection.ConnectionProxy import net.mullvad.mullvadvpn.ui.serviceconnection.LocationInfoCache -import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionContainer import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState import net.mullvad.mullvadvpn.ui.serviceconnection.authTokenCache import net.mullvad.mullvadvpn.ui.serviceconnection.connectionProxy import net.mullvad.mullvadvpn.usecase.NewDeviceNotificationUseCase +import net.mullvad.mullvadvpn.usecase.RelayListUseCase import net.mullvad.mullvadvpn.util.callbackFlowFromNotifier import net.mullvad.mullvadvpn.util.combine import net.mullvad.mullvadvpn.util.daysFromNow @@ -48,7 +48,8 @@ class ConnectViewModel( accountRepository: AccountRepository, private val deviceRepository: DeviceRepository, private val inAppNotificationController: InAppNotificationController, - private val newDeviceNotificationUseCase: NewDeviceNotificationUseCase + private val newDeviceNotificationUseCase: NewDeviceNotificationUseCase, + private val relayListUseCase: RelayListUseCase ) : ViewModel() { private val _uiSideEffect = MutableSharedFlow(extraBufferCapacity = 1) val uiSideEffect = _uiSideEffect.asSharedFlow() @@ -71,7 +72,7 @@ class ConnectViewModel( .flatMapLatest { serviceConnection -> combine( serviceConnection.locationInfoCache.locationCallbackFlow(), - serviceConnection.relayListListener.relayListCallbackFlow(), + relayListUseCase.selectedRelayItem(), inAppNotificationController.notifications, serviceConnection.connectionProxy.tunnelUiStateFlow(), serviceConnection.connectionProxy.tunnelRealStateFlow(), @@ -137,11 +138,6 @@ class ConnectViewModel( awaitClose { onNewLocation = null } } - private fun RelayListListener.relayListCallbackFlow() = callbackFlow { - onRelayCountriesChange = { _, item -> this.trySend(item) } - awaitClose { onRelayCountriesChange = null } - } - private fun ConnectionProxy.tunnelUiStateFlow(): Flow = callbackFlowFromNotifier(this.onUiStateChange) diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt index 0d8f753d8ff7..5e95674e0a6a 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt @@ -2,60 +2,39 @@ package net.mullvad.mullvadvpn.viewmodel import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope -import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.asSharedFlow -import kotlinx.coroutines.flow.callbackFlow import kotlinx.coroutines.flow.combine -import kotlinx.coroutines.flow.emptyFlow -import kotlinx.coroutines.flow.flatMapLatest -import kotlinx.coroutines.flow.flowOf -import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.launch import net.mullvad.mullvadvpn.compose.state.SelectLocationUiState import net.mullvad.mullvadvpn.relaylist.RelayItem import net.mullvad.mullvadvpn.relaylist.filterOnSearchTerm -import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState import net.mullvad.mullvadvpn.ui.serviceconnection.connectionProxy -import net.mullvad.mullvadvpn.ui.serviceconnection.relayListListener +import net.mullvad.mullvadvpn.usecase.RelayListUseCase -class SelectLocationViewModel(private val serviceConnectionManager: ServiceConnectionManager) : - ViewModel() { +class SelectLocationViewModel( + private val serviceConnectionManager: ServiceConnectionManager, + private val relayListUseCase: RelayListUseCase +) : ViewModel() { private val _closeAction = MutableSharedFlow() private val _enterTransitionEndAction = MutableSharedFlow() private val _searchTerm = MutableStateFlow(EMPTY_SEARCH_TERM) val uiState = - serviceConnectionManager.connectionState - .flatMapLatest { state -> - if (state is ServiceConnectionState.ConnectedReady) { - flowOf(state.container) - } else { - emptyFlow() - } - } - .flatMapLatest { serviceConnection -> - combine(serviceConnection.relayListListener.relayListCallbackFlow(), _searchTerm) { - (relayCountries, relayItem), - searchTerm -> - Triple( - relayCountries.filterOnSearchTerm(searchTerm, relayItem), - relayItem, - searchTerm - ) - } - } - .map { (relayCountries, relayItem, searchTerm) -> - if (searchTerm.isNotEmpty() && relayCountries.isEmpty()) { + combine(relayListUseCase.relayListWithSelection(), _searchTerm) { + (relayCountries, relayItem), + searchTerm -> + val filteredRelayCountries = + relayCountries.filterOnSearchTerm(searchTerm, relayItem) + if (searchTerm.isNotEmpty() && filteredRelayCountries.isEmpty()) { SelectLocationUiState.NoSearchResultFound(searchTerm = searchTerm) } else { SelectLocationUiState.ShowData( - countries = relayCountries, + countries = filteredRelayCountries, selectedRelay = relayItem ) } @@ -72,9 +51,7 @@ class SelectLocationViewModel(private val serviceConnectionManager: ServiceConne val enterTransitionEndAction = _enterTransitionEndAction.asSharedFlow() fun selectRelay(relayItem: RelayItem) { - serviceConnectionManager - .relayListListener() - ?.updateSelectedRelayLocation(relayItem.location) + relayListUseCase.updateSelectedRelayLocation(relayItem.location) serviceConnectionManager.connectionProxy()?.connect() viewModelScope.launch { _closeAction.emit(Unit) } } @@ -87,11 +64,6 @@ class SelectLocationViewModel(private val serviceConnectionManager: ServiceConne viewModelScope.launch { _searchTerm.emit(searchTerm) } } - private fun RelayListListener.relayListCallbackFlow() = callbackFlow { - onRelayCountriesChange = { list, item -> this.trySend(list to item) } - awaitClose { onRelayCountriesChange = null } - } - companion object { private const val EMPTY_SEARCH_TERM = "" } diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModel.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModel.kt index 94abf1da9060..dfae3df53956 100644 --- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModel.kt +++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModel.kt @@ -7,16 +7,11 @@ import androidx.lifecycle.viewModelScope import java.net.InetAddress import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.asSharedFlow -import kotlinx.coroutines.flow.callbackFlow import kotlinx.coroutines.flow.combine -import kotlinx.coroutines.flow.emptyFlow -import kotlinx.coroutines.flow.flatMapLatest -import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.flow.update @@ -35,10 +30,8 @@ import net.mullvad.mullvadvpn.model.Settings import net.mullvad.mullvadvpn.model.Udp2TcpObfuscationSettings import net.mullvad.mullvadvpn.model.WireguardConstraints import net.mullvad.mullvadvpn.repository.SettingsRepository -import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState -import net.mullvad.mullvadvpn.ui.serviceconnection.relayListListener +import net.mullvad.mullvadvpn.usecase.PortRangeUseCase +import net.mullvad.mullvadvpn.usecase.RelayListUseCase import net.mullvad.mullvadvpn.util.isValidMtu import org.apache.commons.validator.routines.InetAddressValidator @@ -46,7 +39,8 @@ class VpnSettingsViewModel( private val repository: SettingsRepository, private val inetAddressValidator: InetAddressValidator, private val resources: Resources, - private val serviceConnectionManager: ServiceConnectionManager, + portRangeUseCase: PortRangeUseCase, + private val relayListUseCase: RelayListUseCase, private val dispatcher: CoroutineDispatcher = Dispatchers.IO ) : ViewModel() { @@ -57,38 +51,26 @@ class VpnSettingsViewModel( private val dialogState = MutableStateFlow(null) private val vmState = - serviceConnectionManager.connectionState - .flatMapLatest { state -> - if (state is ServiceConnectionState.ConnectedReady) { - flowOf(state.container) - } else { - emptyFlow() - } - } - .flatMapLatest { serviceConnection -> - combine( - repository.settingsUpdates, - serviceConnection.relayListListener.portRangesCallbackFlow(), - dialogState - ) { settings, portRanges, dialogState -> - VpnSettingsViewModelState( - mtuValue = settings?.mtuString() ?: "", - isAutoConnectEnabled = settings?.autoConnect ?: false, - isLocalNetworkSharingEnabled = settings?.allowLan ?: false, - isCustomDnsEnabled = settings?.isCustomDnsEnabled() ?: false, - customDnsList = settings?.addresses()?.asStringAddressList() ?: listOf(), - contentBlockersOptions = settings?.contentBlockersSettings() - ?: DefaultDnsOptions(), - isAllowLanEnabled = settings?.allowLan ?: false, - selectedObfuscation = settings?.selectedObfuscationSettings() - ?: SelectedObfuscation.Off, - dialogState = dialogState, - quantumResistant = settings?.quantumResistant() - ?: QuantumResistantState.Off, - selectedWireguardPort = settings?.getWireguardPort() ?: Constraint.Any(), - availablePortRanges = portRanges - ) - } + combine(repository.settingsUpdates, portRangeUseCase.portRanges(), dialogState) { + settings, + portRanges, + dialogState -> + VpnSettingsViewModelState( + mtuValue = settings?.mtuString() ?: "", + isAutoConnectEnabled = settings?.autoConnect ?: false, + isLocalNetworkSharingEnabled = settings?.allowLan ?: false, + isCustomDnsEnabled = settings?.isCustomDnsEnabled() ?: false, + customDnsList = settings?.addresses()?.asStringAddressList() ?: listOf(), + contentBlockersOptions = settings?.contentBlockersSettings() + ?: DefaultDnsOptions(), + isAllowLanEnabled = settings?.allowLan ?: false, + selectedObfuscation = settings?.selectedObfuscationSettings() + ?: SelectedObfuscation.Off, + dialogState = dialogState, + quantumResistant = settings?.quantumResistant() ?: QuantumResistantState.Off, + selectedWireguardPort = settings?.getWireguardPort() ?: Constraint.Any(), + availablePortRanges = portRanges + ) } .stateIn( viewModelScope, @@ -351,11 +333,7 @@ class VpnSettingsViewModel( } fun onWireguardPortSelected(port: Constraint) { - viewModelScope.launch(dispatcher) { - serviceConnectionManager - .relayListListener() - ?.updateSelectedWireguardConstraints(WireguardConstraints(port = port)) - } + relayListUseCase.updateSelectedWireguardConstraints(WireguardConstraints(port = port)) hideDialog() } @@ -423,11 +401,6 @@ class VpnSettingsViewModel( private fun Settings.selectedObfuscationSettings() = obfuscationSettings.selectedObfuscation - private fun RelayListListener.portRangesCallbackFlow() = callbackFlow { - onPortRangesChange = { portRanges -> this.trySend(portRanges) } - awaitClose { onPortRangesChange = null } - } - private fun Settings.getWireguardPort() = when (relaySettings) { RelaySettings.CustomTunnelEndpoint -> Constraint.Any() diff --git a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt index 5839e575c19e..1b2e262b3dd4 100644 --- a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/ConnectViewModelTest.kt @@ -34,12 +34,12 @@ import net.mullvad.mullvadvpn.ui.serviceconnection.AppVersionInfoCache import net.mullvad.mullvadvpn.ui.serviceconnection.AuthTokenCache import net.mullvad.mullvadvpn.ui.serviceconnection.ConnectionProxy import net.mullvad.mullvadvpn.ui.serviceconnection.LocationInfoCache -import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionContainer import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState import net.mullvad.mullvadvpn.ui.serviceconnection.authTokenCache import net.mullvad.mullvadvpn.ui.serviceconnection.connectionProxy +import net.mullvad.mullvadvpn.usecase.RelayListUseCase import net.mullvad.mullvadvpn.util.appVersionCallbackFlow import net.mullvad.talpid.tunnel.ErrorState import net.mullvad.talpid.tunnel.ErrorStateCause @@ -73,7 +73,6 @@ class ConnectViewModelTest { // Service connections private val mockServiceConnectionContainer: ServiceConnectionContainer = mockk() private val mockLocationInfoCache: LocationInfoCache = mockk(relaxUnitFun = true) - private val mockRelayListListener: RelayListListener = mockk(relaxUnitFun = true) private lateinit var mockAppVersionInfoCache: AppVersionInfoCache private val mockConnectionProxy: ConnectionProxy = mockk() private val mockLocation: GeoIpLocation = mockk(relaxed = true) @@ -87,14 +86,19 @@ class ConnectViewModelTest { // In App Notifications private val mockInAppNotificationController: InAppNotificationController = mockk() + // Relay list use case + private val mockRelayListUseCase: RelayListUseCase = mockk() + // Captures private val locationSlot = slot<((GeoIpLocation?) -> Unit)>() - private val relaySlot = slot<(List, RelayItem?) -> Unit>() // Event notifiers private val eventNotifierTunnelUiState = EventNotifier(TunnelState.Disconnected) private val eventNotifierTunnelRealState = EventNotifier(TunnelState.Disconnected) + // Flows + private val selectedRelayFlow = MutableStateFlow(null) + @Before fun setup() { mockkStatic(CACHE_EXTENSION_CLASS) @@ -107,7 +111,6 @@ class ConnectViewModelTest { every { mockServiceConnectionManager.connectionState } returns serviceConnectionState every { mockServiceConnectionContainer.locationInfoCache } returns mockLocationInfoCache - every { mockServiceConnectionContainer.relayListListener } returns mockRelayListListener every { mockServiceConnectionContainer.appVersionInfoCache } returns mockAppVersionInfoCache every { mockServiceConnectionContainer.connectionProxy } returns mockConnectionProxy @@ -124,15 +127,18 @@ class ConnectViewModelTest { // Listeners every { mockLocationInfoCache.onNewLocation = capture(locationSlot) } answers {} - every { mockRelayListListener.onRelayCountriesChange = capture(relaySlot) } answers {} every { mockAppVersionInfoCache.onUpdate = any() } answers {} + // Flows + every { mockRelayListUseCase.selectedRelayItem() } returns selectedRelayFlow + viewModel = ConnectViewModel( serviceConnectionManager = mockServiceConnectionManager, accountRepository = mockAccountRepository, deviceRepository = mockDeviceRepository, inAppNotificationController = mockInAppNotificationController, + relayListUseCase = mockRelayListUseCase, newDeviceNotificationUseCase = mockk() ) } @@ -156,7 +162,6 @@ class ConnectViewModelTest { serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(mockLocation) - relaySlot.captured.invoke(mockk(), mockk()) viewModel.toggleTunnelInfoExpansion() val result = awaitItem() assertTrue(result.isTunnelInfoExpanded) @@ -173,7 +178,6 @@ class ConnectViewModelTest { serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(mockLocation) - relaySlot.captured.invoke(mockk(), mockk()) eventNotifierTunnelRealState.notify(tunnelRealStateTestItem) val result = awaitItem() assertEquals(tunnelRealStateTestItem, result.tunnelRealState) @@ -190,7 +194,6 @@ class ConnectViewModelTest { serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(mockLocation) - relaySlot.captured.invoke(mockk(), mockk()) eventNotifierTunnelUiState.notify(tunnelUiStateTestItem) val result = awaitItem() assertEquals(tunnelUiStateTestItem, result.tunnelUiState) @@ -202,13 +205,13 @@ class ConnectViewModelTest { runTest(testCoroutineRule.testDispatcher) { val relayTestItem = RelayCountry(name = "Name", code = "Code", expanded = false, cities = emptyList()) + selectedRelayFlow.value = relayTestItem viewModel.uiState.test { assertEquals(ConnectUiState.INITIAL, awaitItem()) serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(mockLocation) - relaySlot.captured.invoke(mockk(), relayTestItem) val result = awaitItem() assertEquals(relayTestItem, result.relayLocation) } @@ -231,7 +234,6 @@ class ConnectViewModelTest { serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(locationTestItem) - relaySlot.captured.invoke(mockk(), mockk()) val result = awaitItem() assertEquals(locationTestItem, result.location) } @@ -249,7 +251,6 @@ class ConnectViewModelTest { serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(locationTestItem) - relaySlot.captured.invoke(mockk(), mockk()) expectNoEvents() val result = awaitItem() assertEquals(locationTestItem, result.location) @@ -308,7 +309,6 @@ class ConnectViewModelTest { serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(mockLocation) - relaySlot.captured.invoke(mockk(), mockk()) eventNotifierTunnelUiState.notify(tunnelUiState) val result = awaitItem() assertEquals(expectedConnectNotificationState, result.inAppNotification) @@ -347,7 +347,6 @@ class ConnectViewModelTest { serviceConnectionState.value = ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) locationSlot.captured.invoke(mockLocation) - relaySlot.captured.invoke(mockk(), mockk()) eventNotifierTunnelRealState.notify(tunnelRealStateTestItem) awaitItem() } diff --git a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt index 3cadfe575fb1..44be67fa648f 100644 --- a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt @@ -5,7 +5,6 @@ import app.cash.turbine.test import io.mockk.every import io.mockk.mockk import io.mockk.mockkStatic -import io.mockk.slot import io.mockk.unmockkAll import io.mockk.verify import kotlin.test.assertEquals @@ -19,14 +18,12 @@ import net.mullvad.mullvadvpn.lib.common.test.assertLists import net.mullvad.mullvadvpn.model.GeographicLocationConstraint import net.mullvad.mullvadvpn.relaylist.RelayCountry import net.mullvad.mullvadvpn.relaylist.RelayItem +import net.mullvad.mullvadvpn.relaylist.RelayList import net.mullvad.mullvadvpn.relaylist.filterOnSearchTerm import net.mullvad.mullvadvpn.ui.serviceconnection.ConnectionProxy -import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionContainer import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState import net.mullvad.mullvadvpn.ui.serviceconnection.connectionProxy -import net.mullvad.mullvadvpn.ui.serviceconnection.relayListListener +import net.mullvad.mullvadvpn.usecase.RelayListUseCase import org.junit.After import org.junit.Before import org.junit.Rule @@ -38,28 +35,18 @@ class SelectLocationViewModelTest { private val mockServiceConnectionManager: ServiceConnectionManager = mockk() private lateinit var viewModel: SelectLocationViewModel - // Service connections - private val mockServiceConnectionContainer: ServiceConnectionContainer = mockk() - private val mockRelayListListener: RelayListListener = mockk(relaxUnitFun = true) + private val relayListWithSelectionFlow = MutableStateFlow(RelayList(emptyList(), null)) - // Captures - private val relaySlot = slot<(List, RelayItem?) -> Unit>() - - private val serviceConnectionState = - MutableStateFlow(ServiceConnectionState.Disconnected) + private val mockRelayListUseCase: RelayListUseCase = mockk() @Before fun setup() { - every { mockServiceConnectionManager.connectionState } returns serviceConnectionState - every { mockServiceConnectionContainer.relayListListener } returns mockRelayListListener - - every { mockRelayListListener.onRelayCountriesChange = capture(relaySlot) } answers {} - every { mockRelayListListener.onRelayCountriesChange = null } answers {} + every { mockRelayListUseCase.relayListWithSelection() } returns relayListWithSelectionFlow mockkStatic(SERVICE_CONNECTION_MANAGER_EXTENSIONS) mockkStatic(RELAY_LIST_EXTENSIONS) - viewModel = SelectLocationViewModel(mockServiceConnectionManager) + viewModel = SelectLocationViewModel(mockServiceConnectionManager, mockRelayListUseCase) } @After @@ -70,7 +57,7 @@ class SelectLocationViewModelTest { @Test fun testInitialState() = runTest { - viewModel.uiState.test { assertEquals(SelectLocationUiState.Loading, awaitItem()) } + assertEquals(SelectLocationUiState.Loading, viewModel.uiState.value) } @Test @@ -79,14 +66,10 @@ class SelectLocationViewModelTest { val mockCountries = listOf(mockk(), mockk()) val selectedRelay: RelayItem = mockk() every { mockCountries.filterOnSearchTerm(any(), selectedRelay) } returns mockCountries + relayListWithSelectionFlow.value = RelayList(mockCountries, selectedRelay) // Act, Assert viewModel.uiState.test { - serviceConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - relaySlot.captured.invoke(mockCountries, selectedRelay) - - assertEquals(SelectLocationUiState.Loading, awaitItem()) val actualState = awaitItem() assertIs(actualState) assertLists(mockCountries, actualState.countries) @@ -100,14 +83,10 @@ class SelectLocationViewModelTest { val mockCountries = listOf(mockk(), mockk()) val selectedRelay: RelayItem? = null every { mockCountries.filterOnSearchTerm(any(), selectedRelay) } returns mockCountries + relayListWithSelectionFlow.value = RelayList(mockCountries, selectedRelay) // Act, Assert viewModel.uiState.test { - serviceConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - relaySlot.captured.invoke(mockCountries, selectedRelay) - - assertEquals(SelectLocationUiState.Loading, awaitItem()) val actualState = awaitItem() assertIs(actualState) assertLists(mockCountries, actualState.countries) @@ -122,8 +101,8 @@ class SelectLocationViewModelTest { val mockLocation: GeographicLocationConstraint.Country = mockk(relaxed = true) val connectionProxyMock: ConnectionProxy = mockk(relaxUnitFun = true) every { mockRelayItem.location } returns mockLocation - every { mockServiceConnectionManager.relayListListener() } returns mockRelayListListener every { mockServiceConnectionManager.connectionProxy() } returns connectionProxyMock + every { mockRelayListUseCase.updateSelectedRelayLocation(mockLocation) } returns Unit // Act, Assert viewModel.uiCloseAction.test { @@ -132,7 +111,7 @@ class SelectLocationViewModelTest { assertEquals(Unit, awaitItem()) verify { connectionProxyMock.connect() - mockRelayListListener.updateSelectedRelayLocation(mockLocation) + mockRelayListUseCase.updateSelectedRelayLocation(mockLocation) } } } @@ -146,15 +125,10 @@ class SelectLocationViewModelTest { val mockSearchString = "SEARCH" every { mockRelayList.filterOnSearchTerm(mockSearchString, selectedRelay) } returns mockCountries + relayListWithSelectionFlow.value = RelayList(mockRelayList, selectedRelay) // Act, Assert viewModel.uiState.test { - serviceConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - relaySlot.captured.invoke(mockRelayList, selectedRelay) - - // Wait for loading - assertEquals(SelectLocationUiState.Loading, awaitItem()) // Wait for first data assertIs(awaitItem()) @@ -178,15 +152,10 @@ class SelectLocationViewModelTest { val mockSearchString = "SEARCH" every { mockRelayList.filterOnSearchTerm(mockSearchString, selectedRelay) } returns mockCountries + relayListWithSelectionFlow.value = RelayList(mockRelayList, selectedRelay) // Act, Assert viewModel.uiState.test { - serviceConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - relaySlot.captured.invoke(mockRelayList, selectedRelay) - - // Wait for loading - assertEquals(SelectLocationUiState.Loading, awaitItem()) // Wait for first data assertIs(awaitItem()) diff --git a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModelTest.kt b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModelTest.kt index 13561737c830..f8736eb823ea 100644 --- a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModelTest.kt +++ b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/VpnSettingsViewModelTest.kt @@ -5,7 +5,6 @@ import androidx.lifecycle.viewModelScope import app.cash.turbine.test import io.mockk.every import io.mockk.mockk -import io.mockk.slot import io.mockk.unmockkAll import io.mockk.verify import kotlin.test.assertEquals @@ -30,10 +29,8 @@ import net.mullvad.mullvadvpn.model.TunnelOptions import net.mullvad.mullvadvpn.model.WireguardConstraints import net.mullvad.mullvadvpn.model.WireguardTunnelOptions import net.mullvad.mullvadvpn.repository.SettingsRepository -import net.mullvad.mullvadvpn.ui.serviceconnection.RelayListListener -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionContainer -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager -import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionState +import net.mullvad.mullvadvpn.usecase.PortRangeUseCase +import net.mullvad.mullvadvpn.usecase.RelayListUseCase import org.apache.commons.validator.routines.InetAddressValidator import org.junit.After import org.junit.Before @@ -46,34 +43,26 @@ class VpnSettingsViewModelTest { private val mockSettingsRepository: SettingsRepository = mockk() private val mockInetAddressValidator: InetAddressValidator = mockk() private val mockResources: Resources = mockk() - private val mockServiceConnectionManager: ServiceConnectionManager = mockk() - - private val mockServiceConnectionContainer: ServiceConnectionContainer = mockk() - private val mockRelayListListener: RelayListListener = mockk() - private val portRangeSlot = slot<(List) -> Unit>() + private val mockPortRangeUseCase: PortRangeUseCase = mockk() + private val mockRelayListUseCase: RelayListUseCase = mockk() private val mockSettingsUpdate = MutableStateFlow(null) - private val mockConnectionState = - MutableStateFlow(ServiceConnectionState.Disconnected) + private val portRangeFlow = MutableStateFlow(emptyList()) private lateinit var viewModel: VpnSettingsViewModel @Before fun setUp() { every { mockSettingsRepository.settingsUpdates } returns mockSettingsUpdate - every { mockServiceConnectionManager.connectionState } returns mockConnectionState - - every { mockServiceConnectionContainer.relayListListener } returns mockRelayListListener - - every { mockRelayListListener.onPortRangesChange = capture(portRangeSlot) } answers {} - every { mockRelayListListener.onPortRangesChange = null } answers {} + every { mockPortRangeUseCase.portRanges() } returns portRangeFlow viewModel = VpnSettingsViewModel( repository = mockSettingsRepository, inetAddressValidator = mockInetAddressValidator, resources = mockResources, - serviceConnectionManager = mockServiceConnectionManager, + portRangeUseCase = mockPortRangeUseCase, + relayListUseCase = mockRelayListUseCase, dispatcher = UnconfinedTestDispatcher() ) } @@ -122,9 +111,6 @@ class VpnSettingsViewModelTest { viewModel.uiState.test { assertEquals(defaultResistantState, awaitItem().quantumResistant) mockSettingsUpdate.value = mockSettings - mockConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - portRangeSlot.captured.invoke(emptyList()) assertEquals(expectedResistantState, awaitItem().quantumResistant) } } @@ -147,9 +133,6 @@ class VpnSettingsViewModelTest { viewModel.uiState.test { assertIs>(awaitItem().selectedWireguardPort) mockSettingsUpdate.value = mockSettings - mockConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - portRangeSlot.captured.invoke(emptyList()) assertEquals(expectedPort, awaitItem().selectedWireguardPort) } } @@ -159,18 +142,14 @@ class VpnSettingsViewModelTest { // Arrange val wireguardPort: Constraint = Constraint.Only(Port(99)) val wireguardConstraints = WireguardConstraints(port = wireguardPort) - every { - mockRelayListListener.updateSelectedWireguardConstraints(wireguardConstraints) - } returns Unit + every { mockRelayListUseCase.updateSelectedWireguardConstraints(any()) } returns Unit // Act - mockConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) viewModel.onWireguardPortSelected(wireguardPort) // Assert verify(exactly = 1) { - mockRelayListListener.updateSelectedWireguardConstraints(wireguardConstraints) + mockRelayListUseCase.updateSelectedWireguardConstraints(wireguardConstraints) } } @@ -181,15 +160,12 @@ class VpnSettingsViewModelTest { val mockSettings: Settings = mockk(relaxed = true) every { mockSettings.relaySettings } returns mockk(relaxed = true) + portRangeFlow.value = expectedPortRange // Act, Assert viewModel.uiState.test { assertIs(awaitItem()) - mockSettingsUpdate.value = mockSettings viewModel.onWireguardPortInfoClicked() - mockConnectionState.value = - ServiceConnectionState.ConnectedReady(mockServiceConnectionContainer) - portRangeSlot.captured.invoke(expectedPortRange) val state = awaitItem() assertTrue { state.dialog is VpnSettingsDialog.WireguardPortInfo } assertLists(expectedPortRange, state.availablePortRanges) diff --git a/android/lib/ipc/src/main/kotlin/net/mullvad/mullvadvpn/lib/ipc/Request.kt b/android/lib/ipc/src/main/kotlin/net/mullvad/mullvadvpn/lib/ipc/Request.kt index cbd1f28b2750..38237e84b3e2 100644 --- a/android/lib/ipc/src/main/kotlin/net/mullvad/mullvadvpn/lib/ipc/Request.kt +++ b/android/lib/ipc/src/main/kotlin/net/mullvad/mullvadvpn/lib/ipc/Request.kt @@ -4,9 +4,12 @@ import android.os.Message as RawMessage import android.os.Messenger import java.net.InetAddress import kotlinx.parcelize.Parcelize +import net.mullvad.mullvadvpn.model.Constraint import net.mullvad.mullvadvpn.model.DnsOptions import net.mullvad.mullvadvpn.model.GeographicLocationConstraint import net.mullvad.mullvadvpn.model.ObfuscationSettings +import net.mullvad.mullvadvpn.model.Ownership +import net.mullvad.mullvadvpn.model.Providers import net.mullvad.mullvadvpn.model.QuantumResistantState import net.mullvad.mullvadvpn.model.WireguardConstraints @@ -95,6 +98,12 @@ sealed class Request : Message.RequestMessage() { data class SetWireGuardQuantumResistant(val quantumResistant: QuantumResistantState) : Request() + @Parcelize data class SetOwnership(val ownership: Constraint) : Request() + + @Parcelize data class SetProviders(val providers: Constraint) : Request() + + @Parcelize data object FetchRelayList : Request() + companion object { private const val MESSAGE_KEY = "request" diff --git a/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelaySettings.kt b/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelaySettings.kt index 381305f2c345..642046f1b84b 100644 --- a/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelaySettings.kt +++ b/android/lib/model/src/main/kotlin/net/mullvad/mullvadvpn/model/RelaySettings.kt @@ -7,4 +7,6 @@ sealed class RelaySettings : Parcelable { @Parcelize data object CustomTunnelEndpoint : RelaySettings() @Parcelize data class Normal(val relayConstraints: RelayConstraints) : RelaySettings() + + fun relayConstraints(): RelayConstraints? = (this as? Normal)?.relayConstraints } diff --git a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt index 7a0b3fbe970e..186ac21092cf 100644 --- a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt +++ b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/endpoint/RelayListListener.kt @@ -1,16 +1,16 @@ package net.mullvad.mullvadvpn.service.endpoint import kotlin.properties.Delegates.observable +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.GlobalScope -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.channels.ClosedReceiveChannelException -import kotlinx.coroutines.channels.actor -import kotlinx.coroutines.channels.trySendBlocking +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.launch import net.mullvad.mullvadvpn.lib.ipc.Event import net.mullvad.mullvadvpn.lib.ipc.Request import net.mullvad.mullvadvpn.model.Constraint -import net.mullvad.mullvadvpn.model.GeographicLocationConstraint import net.mullvad.mullvadvpn.model.LocationConstraint import net.mullvad.mullvadvpn.model.RelayConstraints import net.mullvad.mullvadvpn.model.RelayList @@ -18,20 +18,13 @@ import net.mullvad.mullvadvpn.model.RelaySettings import net.mullvad.mullvadvpn.model.WireguardConstraints import net.mullvad.mullvadvpn.service.MullvadDaemon -class RelayListListener(endpoint: ServiceEndpoint) { - - private val commandChannel = spawnActor() +class RelayListListener( + endpoint: ServiceEndpoint, + dispatcher: CoroutineDispatcher = Dispatchers.IO +) { + private val scope: CoroutineScope = CoroutineScope(SupervisorJob() + dispatcher) private val daemon = endpoint.intermittentDaemon - private var selectedRelayLocation by - observable(null) { _, _, _ -> - commandChannel.trySendBlocking(Command.SetRelayLocation) - } - private var selectedWireguardConstraints by - observable(null) { _, _, _ -> - commandChannel.trySendBlocking(Command.SetWireguardConstraints) - } - var relayList by observable(null) { _, _, relays -> endpoint.sendEvent(Event.NewRelayList(relays)) @@ -46,18 +39,59 @@ class RelayListListener(endpoint: ServiceEndpoint) { } } - endpoint.dispatcher.registerHandler(Request.SetRelayLocation::class) { request -> - selectedRelayLocation = request.relayLocation + scope.launch { + endpoint.dispatcher.parsedMessages + .filterIsInstance() + .collect { request -> + val update = + getCurrentRelayConstraints() + .copy( + location = + Constraint.Only( + LocationConstraint.Location(request.relayLocation) + ) + ) + daemon.await().setRelaySettings(RelaySettings.Normal(update)) + } + } + + scope.launch { + endpoint.dispatcher.parsedMessages + .filterIsInstance() + .collect { request -> + val update = + getCurrentRelayConstraints() + .copy(wireguardConstraints = request.wireguardConstraints) + daemon.await().setRelaySettings(RelaySettings.Normal(update)) + } } - endpoint.dispatcher.registerHandler(Request.SetWireguardConstraints::class) { request -> - selectedWireguardConstraints = request.wireguardConstraints + scope.launch { + endpoint.dispatcher.parsedMessages.filterIsInstance().collect { + request -> + val update = getCurrentRelayConstraints().copy(ownership = request.ownership) + daemon.await().setRelaySettings(RelaySettings.Normal(update)) + } + } + + scope.launch { + endpoint.dispatcher.parsedMessages.filterIsInstance().collect { + request -> + val update = getCurrentRelayConstraints().copy(providers = request.providers) + daemon.await().setRelaySettings(RelaySettings.Normal(update)) + } + } + + scope.launch { + endpoint.dispatcher.parsedMessages.filterIsInstance().collect { + relayList = daemon.await().getRelayLocations() + } } } fun onDestroy() { - commandChannel.close() daemon.unregisterListener(this) + scope.cancel() } private fun setUpListener(daemon: MullvadDaemon) { @@ -72,43 +106,6 @@ class RelayListListener(endpoint: ServiceEndpoint) { } } - private fun spawnActor() = - GlobalScope.actor(Dispatchers.Default, Channel.CONFLATED) { - try { - for (command in channel) { - when (command) { - Command.SetRelayLocation, - Command.SetWireguardConstraints -> updateRelayConstraints() - } - } - } catch (exception: ClosedReceiveChannelException) { - // Closed sender, so stop the actor - } - } - - private suspend fun updateRelayConstraints() { - val currentRelayConstraints = getCurrentRelayConstraints() - val location: Constraint = - selectedRelayLocation?.let { location -> - Constraint.Only(LocationConstraint.Location(location)) - } - ?: currentRelayConstraints.location - val wireguardConstraints: WireguardConstraints = - selectedWireguardConstraints ?: currentRelayConstraints.wireguardConstraints - - val update = - RelaySettings.Normal( - RelayConstraints( - location = location, - wireguardConstraints = wireguardConstraints, - ownership = Constraint.Any(), - providers = Constraint.Any() - ) - ) - - daemon.await().setRelaySettings(update) - } - private suspend fun getCurrentRelayConstraints(): RelayConstraints = when (val relaySettings = daemon.await().getSettings()?.relaySettings) { is RelaySettings.Normal -> relaySettings.relayConstraints @@ -120,11 +117,4 @@ class RelayListListener(endpoint: ServiceEndpoint) { wireguardConstraints = WireguardConstraints(Constraint.Any()) ) } - - companion object { - private enum class Command { - SetRelayLocation, - SetWireguardConstraints - } - } }