diff --git a/Cargo.lock b/Cargo.lock index e7a426717dff..d280ba649b82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -431,9 +431,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" [[package]] name = "c2rust-bitfields" @@ -1248,9 +1248,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1279,9 +1279,9 @@ checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -2107,9 +2107,9 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "jnix" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd797d41e48568eb956ded20d7e5e3f2df1c02980d9e5b9aab9b47bd3a9f599" +checksum = "542b2072131a62ec940ee161ff0a01e7a1c2a129796b30143efc952cb6e0f28f" dependencies = [ "jni", "jnix-macros", @@ -2196,9 +2196,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.158" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libdbus-sys" @@ -2888,9 +2888,9 @@ dependencies = [ [[package]] name = "netlink-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "416060d346fbaf1f23f9512963e3e878f1a78e707cb699ba9215761754244307" +checksum = "16c903aa70590cb93691bf97a767c8d1d6122d2cc9070433deb3bbf36ce8bd23" dependencies = [ "bytes", "futures", @@ -4789,6 +4789,7 @@ dependencies = [ "bitflags 2.6.0", "futures", "ipnetwork", + "jnix", "libc", "log", "netlink-packet-route", diff --git a/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt new file mode 100644 index 000000000000..27e7658a11de --- /dev/null +++ b/android/app/src/test/kotlin/net/mullvad/talpid/TalpidVpnServiceFallbackDnsTest.kt @@ -0,0 +1,146 @@ +package net.mullvad.talpid + +import android.net.VpnService +import android.os.ParcelFileDescriptor +import arrow.core.right +import io.mockk.MockKAnnotations +import io.mockk.coVerify +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkConstructor +import io.mockk.mockkStatic +import io.mockk.spyk +import java.net.InetAddress +import net.mullvad.mullvadvpn.lib.common.test.assertLists +import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe +import net.mullvad.mullvadvpn.lib.model.Prepared +import net.mullvad.talpid.model.CreateTunResult +import net.mullvad.talpid.model.InetNetwork +import net.mullvad.talpid.model.TunConfig +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertInstanceOf + +class TalpidVpnServiceFallbackDnsTest { + lateinit var talpidVpnService: TalpidVpnService + var builderMockk = mockk() + + @BeforeEach + fun setup() { + MockKAnnotations.init(this) + mockkStatic(VPN_SERVICE_EXTENSION) + + talpidVpnService = spyk(recordPrivateCalls = true) + every { talpidVpnService.prepareVpnSafe() } returns Prepared.right() + builderMockk = mockk() + + mockkConstructor(VpnService.Builder::class) + every { anyConstructed().setMtu(any()) } returns builderMockk + every { anyConstructed().setBlocking(any()) } returns builderMockk + every { anyConstructed().addAddress(any(), any()) } returns + builderMockk + every { anyConstructed().addRoute(any(), any()) } returns + builderMockk + every { + anyConstructed() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } returns builderMockk + val parcelFileDescriptor: ParcelFileDescriptor = mockk() + every { anyConstructed().establish() } returns parcelFileDescriptor + every { parcelFileDescriptor.detachFd() } returns 1 + } + + @Test + fun `opening tun with no DnsServers should add fallback DNS server`() { + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf()) + + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf(result) + + // Fallback DNS server should be added if no DNS servers are provided + coVerify(exactly = 1) { + anyConstructed() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + @Test + fun `opening tun with all bad DnsServers should return InvalidDnsServers and add fallback`() { + val badDns1 = InetAddress.getByName("0.0.0.0") + val badDns2 = InetAddress.getByName("255.255.255.255") + every { anyConstructed().addDnsServer(badDns1) } throws + IllegalArgumentException() + every { anyConstructed().addDnsServer(badDns2) } throws + IllegalArgumentException() + + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(badDns1, badDns2)) + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf(result) + assertLists(tunConfig.dnsServers, result.addresses) + // Fallback DNS server should be added if no valid DNS servers are provided + coVerify(exactly = 1) { + anyConstructed() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + @Test + fun `opening tun with 1 good and 1 bad DnsServers should return InvalidDnsServers`() { + val goodDnsServer = InetAddress.getByName("1.1.1.1") + val badDns = InetAddress.getByName("255.255.255.255") + every { anyConstructed().addDnsServer(goodDnsServer) } returns + builderMockk + every { anyConstructed().addDnsServer(badDns) } throws + IllegalArgumentException() + + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer, badDns)) + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf(result) + assertLists(arrayListOf(badDns), result.addresses) + + // Fallback DNS server should not be added since we have 1 good DNS server + coVerify(exactly = 0) { + anyConstructed() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + @Test + fun `providing good dns servers should not add the fallback dns and return success`() { + val goodDnsServer = InetAddress.getByName("1.1.1.1") + every { anyConstructed().addDnsServer(goodDnsServer) } returns + builderMockk + + val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer)) + val result = talpidVpnService.openTun(tunConfig) + + assertInstanceOf(result) + + // Fallback DNS server should not be added since we have good DNS servers. + coVerify(exactly = 0) { + anyConstructed() + .addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER) + } + } + + companion object { + private const val VPN_SERVICE_EXTENSION = + "net.mullvad.mullvadvpn.lib.common.util.VpnServiceUtilsKt" + + val baseTunConfig = + TunConfig( + addresses = arrayListOf(InetAddress.getByName("45.83.223.209")), + dnsServers = arrayListOf(), + routes = + arrayListOf( + InetNetwork(InetAddress.getByName("0.0.0.0"), 0), + InetNetwork(InetAddress.getByName("::"), 0), + ), + mtu = 1280, + excludedPackages = arrayListOf(), + ) + } +} diff --git a/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt b/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt index 59833cb3961d..06c862936b65 100644 --- a/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt +++ b/android/lib/common/src/main/kotlin/net/mullvad/mullvadvpn/lib/common/util/VpnServiceUtils.kt @@ -2,10 +2,14 @@ package net.mullvad.mullvadvpn.lib.common.util import android.content.Context import android.content.Intent +import android.net.VpnService import android.net.VpnService.prepare +import android.os.ParcelFileDescriptor import arrow.core.Either -import arrow.core.flatten +import arrow.core.flatMap import arrow.core.left +import arrow.core.raise.either +import arrow.core.raise.ensureNotNull import arrow.core.right import co.touchlab.kermit.Logger import net.mullvad.mullvadvpn.lib.common.util.SdkUtils.getInstalledPackagesList @@ -13,6 +17,8 @@ import net.mullvad.mullvadvpn.lib.model.PrepareError import net.mullvad.mullvadvpn.lib.model.Prepared /** + * Prepare to establish a VPN connection safely. + * * Invoking VpnService.prepare() can result in 3 out comes: * 1. IllegalStateException - There is a legacy VPN profile marked as always on * 2. Intent @@ -34,7 +40,7 @@ fun Context.prepareVpnSafe(): Either = else -> throw it } } - .map { intent -> + .flatMap { intent -> if (intent == null) { Prepared.right() } else { @@ -46,7 +52,6 @@ fun Context.prepareVpnSafe(): Either = } } } - .flatten() fun Context.getAlwaysOnVpnAppName(): String? { return resolveAlwaysOnVpnPackageName() @@ -59,3 +64,38 @@ fun Context.getAlwaysOnVpnAppName(): String? { ?.loadLabel(packageManager) ?.toString() } + +/** + * Establish a VPN connection safely. + * + * This function wraps the [VpnService.Builder.establish] function and catches any exceptions that + * may be thrown and type them to a more specific error. + * + * @return [ParcelFileDescriptor] if successful, [EstablishError] otherwise + */ +fun VpnService.Builder.establishSafe(): Either = either { + val vpnInterfaceFd = + Either.catch { establish() } + .mapLeft { + when (it) { + is IllegalStateException -> EstablishError.ParameterNotApplied(it) + is IllegalArgumentException -> EstablishError.ParameterNotAccepted(it) + else -> EstablishError.UnknownError(it) + } + } + .bind() + + ensureNotNull(vpnInterfaceFd) { EstablishError.NullVpnInterface } + + vpnInterfaceFd +} + +sealed interface EstablishError { + data class ParameterNotApplied(val exception: IllegalStateException) : EstablishError + + data class ParameterNotAccepted(val exception: IllegalArgumentException) : EstablishError + + data object NullVpnInterface : EstablishError + + data class UnknownError(val error: Throwable) : EstablishError +} diff --git a/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt b/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt index daa04fc8d996..fe4cf11881ba 100644 --- a/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt +++ b/android/lib/daemon-grpc/src/main/kotlin/net/mullvad/mullvadvpn/lib/daemon/grpc/mapper/ToDomain.kt @@ -36,9 +36,6 @@ import net.mullvad.mullvadvpn.lib.model.DnsState import net.mullvad.mullvadvpn.lib.model.Endpoint import net.mullvad.mullvadvpn.lib.model.ErrorState import net.mullvad.mullvadvpn.lib.model.ErrorStateCause -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.AuthFailed -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.OtherAlwaysOnApp -import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.TunnelParameterError import net.mullvad.mullvadvpn.lib.model.FeatureIndicator import net.mullvad.mullvadvpn.lib.model.GeoIpLocation import net.mullvad.mullvadvpn.lib.model.GeoLocationId @@ -125,7 +122,7 @@ private fun ManagementInterface.TunnelState.Error.toDomain(): TunnelState.Error val otherAlwaysOnAppError = errorState.let { if (it.hasOtherAlwaysOnAppError()) { - OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName) + ErrorStateCause.OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName) } else { null } @@ -238,7 +235,7 @@ internal fun ManagementInterface.ErrorState.toDomain( cause = when (cause!!) { ManagementInterface.ErrorState.Cause.AUTH_FAILED -> - AuthFailed(authFailedError.toDomain()) + ErrorStateCause.AuthFailed(authFailedError.toDomain()) ManagementInterface.ErrorState.Cause.IPV6_UNAVAILABLE -> ErrorStateCause.Ipv6Unavailable ManagementInterface.ErrorState.Cause.SET_FIREWALL_POLICY_ERROR -> @@ -247,7 +244,7 @@ internal fun ManagementInterface.ErrorState.toDomain( ManagementInterface.ErrorState.Cause.START_TUNNEL_ERROR -> ErrorStateCause.StartTunnelError ManagementInterface.ErrorState.Cause.TUNNEL_PARAMETER_ERROR -> - TunnelParameterError(parameterError.toDomain()) + ErrorStateCause.TunnelParameterError(parameterError.toDomain()) ManagementInterface.ErrorState.Cause.IS_OFFLINE -> ErrorStateCause.IsOffline ManagementInterface.ErrorState.Cause.SPLIT_TUNNEL_ERROR -> ErrorStateCause.StartTunnelError @@ -255,7 +252,6 @@ internal fun ManagementInterface.ErrorState.toDomain( ManagementInterface.ErrorState.Cause.NEED_FULL_DISK_PERMISSIONS, ManagementInterface.ErrorState.Cause.CREATE_TUNNEL_DEVICE -> throw IllegalArgumentException("Unrecognized error state cause") - ManagementInterface.ErrorState.Cause.NOT_PREPARED -> ErrorStateCause.NotPrepared ManagementInterface.ErrorState.Cause.OTHER_ALWAYS_ON_APP -> otherAlwaysOnApp!! ManagementInterface.ErrorState.Cause.OTHER_LEGACY_ALWAYS_ON_VPN -> diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index 86b27e3ba83d..fdee5039ade6 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -7,34 +7,48 @@ import android.net.NetworkCapabilities import android.net.NetworkRequest import co.touchlab.kermit.Logger import java.net.InetAddress +import kotlin.collections.ArrayList import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.SharingStarted import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.distinctUntilChanged -import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.stateIn +import net.mullvad.talpid.model.NetworkState import net.mullvad.talpid.util.NetworkEvent -import net.mullvad.talpid.util.defaultNetworkFlow -import net.mullvad.talpid.util.networkFlow +import net.mullvad.talpid.util.RawNetworkState +import net.mullvad.talpid.util.defaultRawNetworkStateFlow +import net.mullvad.talpid.util.networkEvents -class ConnectivityListener(val connectivityManager: ConnectivityManager) { +class ConnectivityListener(private val connectivityManager: ConnectivityManager) { private lateinit var _isConnected: StateFlow // Used by JNI val isConnected get() = _isConnected.value - private lateinit var _currentDnsServers: StateFlow> + private lateinit var _currentNetworkState: StateFlow + + // Used by JNI + val currentDefaultNetworkState: NetworkState? + get() = _currentNetworkState.value + // Used by JNI - val currentDnsServers - get() = ArrayList(_currentDnsServers.value) + val currentDnsServers: ArrayList + get() = _currentNetworkState.value?.dnsServers ?: ArrayList() fun register(scope: CoroutineScope) { - _currentDnsServers = - dnsServerChanges().stateIn(scope, SharingStarted.Eagerly, currentDnsServers()) + // Consider implementing retry logic for the flows below, because registering a listener on + // the default network may fail if the network on Android 11 + // https://issuetracker.google.com/issues/175055271?pli=1 + _currentNetworkState = + connectivityManager + .defaultRawNetworkStateFlow() + .map { it?.toNetworkState() } + .onEach { notifyDefaultNetworkChange(it) } + .stateIn(scope, SharingStarted.Eagerly, null) _isConnected = hasInternetCapability() @@ -42,18 +56,6 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .stateIn(scope, SharingStarted.Eagerly, false) } - private fun dnsServerChanges(): Flow> = - connectivityManager - .defaultNetworkFlow() - .filterIsInstance() - .onEach { Logger.d("Link properties changed") } - .map { it.linkProperties.dnsServersWithoutFallback() } - - private fun currentDnsServers(): List = - connectivityManager - .getLinkProperties(connectivityManager.activeNetwork) - ?.dnsServersWithoutFallback() ?: emptyList() - private fun LinkProperties.dnsServersWithoutFallback(): List = dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER } @@ -65,7 +67,7 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .build() return connectivityManager - .networkFlow(request) + .networkEvents(request) .scan(setOf()) { networks, event -> when (event) { is NetworkEvent.Available -> { @@ -87,5 +89,14 @@ class ConnectivityListener(val connectivityManager: ConnectivityManager) { .distinctUntilChanged() } + private fun RawNetworkState.toNetworkState(): NetworkState = + NetworkState( + network.networkHandle, + linkProperties?.routes, + linkProperties?.dnsServersWithoutFallback(), + ) + private external fun notifyConnectivityChange(isConnected: Boolean) + + private external fun notifyDefaultNetworkChange(networkState: NetworkState?) } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index 74d44005cd7c..a143df61322e 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -1,18 +1,29 @@ package net.mullvad.talpid import android.net.ConnectivityManager +import android.net.VpnService import android.os.ParcelFileDescriptor import androidx.annotation.CallSuper import androidx.core.content.getSystemService import androidx.lifecycle.lifecycleScope +import arrow.core.Either +import arrow.core.mapOrAccumulate +import arrow.core.merge +import arrow.core.raise.either import co.touchlab.kermit.Logger import java.net.Inet4Address import java.net.Inet6Address import java.net.InetAddress import kotlin.properties.Delegates.observable +import net.mullvad.mullvadvpn.lib.common.util.establishSafe import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe import net.mullvad.mullvadvpn.lib.model.PrepareError import net.mullvad.talpid.model.CreateTunResult +import net.mullvad.talpid.model.CreateTunResult.EstablishError +import net.mullvad.talpid.model.CreateTunResult.InvalidDnsServers +import net.mullvad.talpid.model.CreateTunResult.NotPrepared +import net.mullvad.talpid.model.CreateTunResult.OtherAlwaysOnApp +import net.mullvad.talpid.model.CreateTunResult.OtherLegacyAlwaysOnVpn import net.mullvad.talpid.model.TunConfig import net.mullvad.talpid.util.TalpidSdkUtils.setMeteredIfSupported @@ -22,7 +33,7 @@ open class TalpidVpnService : LifecycleVpnService() { val oldTunFd = when (oldTunStatus) { is CreateTunResult.Success -> oldTunStatus.tunFd - is CreateTunResult.InvalidDnsServers -> oldTunStatus.tunFd + is InvalidDnsServers -> oldTunStatus.tunFd else -> null } @@ -43,26 +54,30 @@ open class TalpidVpnService : LifecycleVpnService() { connectivityListener.register(lifecycleScope) } - fun openTun(config: TunConfig): CreateTunResult { + // Used by JNI + fun openTun(config: TunConfig): CreateTunResult = synchronized(this) { val tunStatus = activeTunStatus if (config == currentTunConfig && tunStatus != null && tunStatus.isOpen) { - return tunStatus + tunStatus } else { - return openTunImpl(config) + openTunImpl(config) } } - } - fun openTunForced(config: TunConfig): CreateTunResult { - synchronized(this) { - return openTunImpl(config) - } - } + // Used by JNI + fun openTunForced(config: TunConfig): CreateTunResult = + synchronized(this) { openTunImpl(config) } + + // Used by JNI + fun closeTun(): Unit = synchronized(this) { activeTunStatus = null } + + // Used by JNI + fun bypass(socket: Int): Boolean = protect(socket) private fun openTunImpl(config: TunConfig): CreateTunResult { - val newTunStatus = createTun(config) + val newTunStatus = createTun(config).merge() currentTunConfig = config activeTunStatus = newTunStatus @@ -70,95 +85,76 @@ open class TalpidVpnService : LifecycleVpnService() { return newTunStatus } - fun closeTun() { - synchronized(this) { activeTunStatus = null } - } - - // DROID-1407 - // Function is to be cleaned up and lint suppression to be removed. - @Suppress("ReturnCount") - private fun createTun(config: TunConfig): CreateTunResult { - prepareVpnSafe() - .mapLeft { it.toCreateTunResult() } - .onLeft { - return it + private fun createTun( + config: TunConfig + ): Either = either { + prepareVpnSafe().mapLeft { it.toCreateTunError() }.bind() + + val builder = Builder() + builder.setMtu(config.mtu) + builder.setBlocking(false) + builder.setMeteredIfSupported(false) + + config.addresses.forEach { builder.addAddress(it, it.prefixLength()) } + config.routes.forEach { builder.addRoute(it.address, it.prefixLength.toInt()) } + config.excludedPackages.forEach { app -> builder.addDisallowedApplication(app) } + + // We don't care if adding DNS servers fails at this point, since we can still create a + // tunnel to consume traffic and then notify daemon to later enter blocked state. + val dnsConfigureResult = + config.dnsServers.mapOrAccumulate { + builder.addDnsServerSafe(it).bind() + Unit } - val invalidDnsServerAddresses = ArrayList() - - val builder = - Builder().apply { - for (address in config.addresses) { - addAddress(address, address.prefixLength()) - } - - for (dnsServer in config.dnsServers) { - try { - addDnsServer(dnsServer) - } catch (exception: IllegalArgumentException) { - invalidDnsServerAddresses.add(dnsServer) - } - } - - // Avoids creating a tunnel with no DNS servers or if all DNS servers was invalid, - // since apps then may leak DNS requests. - // https://issuetracker.google.com/issues/337961996 - if (invalidDnsServerAddresses.size == config.dnsServers.size) { - Logger.w( - "All DNS servers invalid or non set, using fallback DNS server to " + - "minimize leaks, dnsServers.isEmpty(): ${config.dnsServers.isEmpty()}" - ) - addDnsServer(FALLBACK_DUMMY_DNS_SERVER) - } - - for (route in config.routes) { - addRoute(route.address, route.prefixLength.toInt()) - } - - config.excludedPackages.forEach { app -> addDisallowedApplication(app) } - setMtu(config.mtu) - setBlocking(false) - setMeteredIfSupported(false) - } + // Never create a tunnel where all DNS servers are invalid or if none was ever set, since + // apps then may leak DNS requests. + // https://issuetracker.google.com/issues/337961996 + val shouldAddFallbackDns = + dnsConfigureResult.fold( + { invalidDnsServers -> invalidDnsServers.size == config.dnsServers.size }, + { addedDnsServers -> addedDnsServers.isEmpty() }, + ) + if (shouldAddFallbackDns) { + Logger.w( + "All DNS servers invalid or non set, using fallback DNS server to " + + "minimize leaks, dnsServers.isEmpty(): ${config.dnsServers.isEmpty()}" + ) + builder.addDnsServer(FALLBACK_DUMMY_DNS_SERVER) + } val vpnInterfaceFd = - try { - builder.establish() - } catch (e: IllegalStateException) { - Logger.e("Failed to establish, a parameter could not be applied", e) - return CreateTunResult.TunnelDeviceError - } catch (e: IllegalArgumentException) { - Logger.e("Failed to establish a parameter was not accepted", e) - return CreateTunResult.TunnelDeviceError - } - - if (vpnInterfaceFd == null) { - Logger.e("VpnInterface returned null") - return CreateTunResult.TunnelDeviceError - } + builder + .establishSafe() + .onLeft { Logger.w("Failed to establish tunnel $it") } + .mapLeft { EstablishError } + .bind() val tunFd = vpnInterfaceFd.detachFd() - waitForTunnelUp(tunFd, config.routes.any { route -> route.isIpv6 }) + dnsConfigureResult.mapLeft { InvalidDnsServers(it, tunFd) }.bind() - if (invalidDnsServerAddresses.isNotEmpty()) { - return CreateTunResult.InvalidDnsServers(invalidDnsServerAddresses, tunFd) - } - - return CreateTunResult.Success(tunFd) - } - - fun bypass(socket: Int): Boolean { - return protect(socket) + CreateTunResult.Success(tunFd) } - private fun PrepareError.toCreateTunResult() = + private fun PrepareError.toCreateTunError() = when (this) { - is PrepareError.OtherLegacyAlwaysOnVpn -> CreateTunResult.OtherLegacyAlwaysOnVpn - is PrepareError.NotPrepared -> CreateTunResult.NotPrepared - is PrepareError.OtherAlwaysOnApp -> CreateTunResult.OtherAlwaysOnApp(appName) + is PrepareError.OtherLegacyAlwaysOnVpn -> OtherLegacyAlwaysOnVpn + is PrepareError.NotPrepared -> NotPrepared + is PrepareError.OtherAlwaysOnApp -> OtherAlwaysOnApp(appName) } + private fun Builder.addDnsServerSafe( + dnsServer: InetAddress + ): Either = + Either.catch { addDnsServer(dnsServer) } + .mapLeft { + when (it) { + is IllegalArgumentException -> dnsServer + else -> throw it + } + } + private fun InetAddress.prefixLength(): Int = when (this) { is Inet4Address -> IPV4_PREFIX_LENGTH @@ -166,8 +162,6 @@ open class TalpidVpnService : LifecycleVpnService() { else -> throw IllegalArgumentException("Invalid IP address (not IPv4 nor IPv6)") } - private external fun waitForTunnelUp(tunFd: Int, isIpv6Enabled: Boolean) - companion object { const val FALLBACK_DUMMY_DNS_SERVER = "192.0.2.1" diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt index 3cd73685f715..ef10dcd2f3b7 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/CreateTunResult.kt @@ -1,29 +1,38 @@ package net.mullvad.talpid.model import java.net.InetAddress +import java.util.ArrayList -sealed class CreateTunResult { - open val isOpen - get() = false +sealed interface CreateTunResult { + val isOpen: Boolean - class Success(val tunFd: Int) : CreateTunResult() { - override val isOpen - get() = true + data class Success(val tunFd: Int) : CreateTunResult { + override val isOpen = true } - class InvalidDnsServers(val addresses: ArrayList, val tunFd: Int) : - CreateTunResult() { - override val isOpen - get() = true + sealed interface Error : CreateTunResult + + // Prepare errors + data object OtherLegacyAlwaysOnVpn : Error { + override val isOpen: Boolean = false } - // Establish error - data object TunnelDeviceError : CreateTunResult() + data class OtherAlwaysOnApp(val appName: String) : Error { + override val isOpen: Boolean = false + } - // Prepare errors - data object OtherLegacyAlwaysOnVpn : CreateTunResult() + data object NotPrepared : Error { + override val isOpen: Boolean = false + } - data class OtherAlwaysOnApp(val appName: String) : CreateTunResult() + // Establish error + data object EstablishError : Error { + override val isOpen: Boolean = false + } - data object NotPrepared : CreateTunResult() + data class InvalidDnsServers(val addresses: ArrayList, val tunFd: Int) : Error { + constructor(address: List, tunFd: Int) : this(ArrayList(address), tunFd) + + override val isOpen = true + } } diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt new file mode 100644 index 000000000000..ca0b6db7e22a --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/NetworkState.kt @@ -0,0 +1,19 @@ +package net.mullvad.talpid.model + +import java.net.InetAddress + +data class NetworkState( + val networkHandle: Long, + val routes: ArrayList?, + val dnsServers: ArrayList?, +) { + constructor( + networkHandle: Long, + routes: List?, + dnsServers: List?, + ) : this( + networkHandle = networkHandle, + routes = routes?.map { it.toRoute() }?.let { ArrayList(it) }, + dnsServers = dnsServers?.let { ArrayList(it) }, + ) +} diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt new file mode 100644 index 000000000000..035bdc1ad086 --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/model/RouteInfo.kt @@ -0,0 +1,18 @@ +package net.mullvad.talpid.model + +import java.net.InetAddress + +typealias AndroidRouteInfo = android.net.RouteInfo + +data class RouteInfo( + val destination: InetNetwork, + val gateway: InetAddress?, + val `interface`: String?, +) + +fun AndroidRouteInfo.toRoute() = + RouteInfo( + destination = InetNetwork(destination.address, destination.prefixLength.toShort()), + gateway = gateway, + `interface` = `interface`, + ) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt index daf155c6e8ef..fddaa6fb8806 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt @@ -10,59 +10,56 @@ import kotlinx.coroutines.channels.awaitClose import kotlinx.coroutines.channels.trySendBlocking import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.scan + +internal fun ConnectivityManager.defaultNetworkEvents(): Flow = callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) + } -fun ConnectivityManager.defaultNetworkFlow(): Flow = - callbackFlow { - val callback = - object : NetworkCallback() { - override fun onLinkPropertiesChanged( - network: Network, - linkProperties: LinkProperties, - ) { - super.onLinkPropertiesChanged(network, linkProperties) - trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) - } - - override fun onAvailable(network: Network) { - super.onAvailable(network) - trySendBlocking(NetworkEvent.Available(network)) - } + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.Available(network)) + } - override fun onCapabilitiesChanged( - network: Network, - networkCapabilities: NetworkCapabilities, - ) { - super.onCapabilitiesChanged(network, networkCapabilities) - trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) - } + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) + } - override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { - super.onBlockedStatusChanged(network, blocked) - trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) - } + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) + } - override fun onLosing(network: Network, maxMsToLive: Int) { - super.onLosing(network, maxMsToLive) - trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) - } + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) + } - override fun onLost(network: Network) { - super.onLost(network) - trySendBlocking(NetworkEvent.Lost(network)) - } + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.Lost(network)) + } - override fun onUnavailable() { - super.onUnavailable() - trySendBlocking(NetworkEvent.Unavailable) - } + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.Unavailable) } - registerDefaultNetworkCallback(callback) + } + registerDefaultNetworkCallback(callback) - awaitClose { unregisterNetworkCallback(callback) } - } + awaitClose { unregisterNetworkCallback(callback) } +} -fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow = - callbackFlow { +fun ConnectivityManager.networkEvents(networkRequest: NetworkRequest): Flow = + callbackFlow { val callback = object : NetworkCallback() { override fun onLinkPropertiesChanged( @@ -111,6 +108,26 @@ fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow = + defaultNetworkEvents() + .scan( + null as RawNetworkState?, + { state, event -> + return@scan when (event) { + is NetworkEvent.Available -> RawNetworkState(network = event.network) + is NetworkEvent.BlockedStatusChanged -> + state?.copy(blockedStatus = event.blocked) + is NetworkEvent.CapabilitiesChanged -> + state?.copy(networkCapabilities = event.networkCapabilities) + is NetworkEvent.LinkPropertiesChanged -> + state?.copy(linkProperties = event.linkProperties) + is NetworkEvent.Losing -> state?.copy(maxMsToLive = event.maxMsToLive) + is NetworkEvent.Lost -> null + NetworkEvent.Unavailable -> null + } + }, + ) + sealed interface NetworkEvent { data class Available(val network: Network) : NetworkEvent @@ -130,3 +147,11 @@ sealed interface NetworkEvent { data class Lost(val network: Network) : NetworkEvent } + +internal data class RawNetworkState( + val network: Network, + val linkProperties: LinkProperties? = null, + val networkCapabilities: NetworkCapabilities? = null, + val blockedStatus: Boolean = false, + val maxMsToLive: Int? = null, +) diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index fbd60a8e79ab..6e91e6efa133 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -790,6 +790,8 @@ impl Daemon { mullvad_types::TUNNEL_FWMARK, #[cfg(target_os = "linux")] mullvad_types::TUNNEL_TABLE_ID, + #[cfg(target_os = "android")] + config.android_context.clone(), ) .await .map_err(Error::RouteManager)?; diff --git a/mullvad-jni/src/classes.rs b/mullvad-jni/src/classes.rs index 8312657efb1e..f773d3adca24 100644 --- a/mullvad-jni/src/classes.rs +++ b/mullvad-jni/src/classes.rs @@ -7,12 +7,14 @@ pub const CLASSES: &[&str] = &[ "net/mullvad/mullvadvpn/service/MullvadVpnService", "net/mullvad/talpid/model/InetNetwork", "net/mullvad/talpid/model/TunConfig", + "net/mullvad/talpid/model/NetworkState", + "net/mullvad/talpid/model/RouteInfo", "net/mullvad/talpid/model/CreateTunResult$Success", "net/mullvad/talpid/model/CreateTunResult$InvalidDnsServers", "net/mullvad/talpid/model/CreateTunResult$OtherLegacyAlwaysOnVpn", "net/mullvad/talpid/model/CreateTunResult$OtherAlwaysOnApp", "net/mullvad/talpid/model/CreateTunResult$NotPrepared", - "net/mullvad/talpid/model/CreateTunResult$TunnelDeviceError", + "net/mullvad/talpid/model/CreateTunResult$EstablishError", "net/mullvad/talpid/ConnectivityListener", "net/mullvad/talpid/TalpidVpnService", "net/mullvad/mullvadvpn/lib/endpoint/ApiEndpointOverride", diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index fd35396fd00f..8b1018d926ab 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -3,7 +3,6 @@ mod api; mod classes; mod problem_report; -mod talpid_vpn_service; use jnix::{ jni::{ @@ -88,17 +87,25 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_initial assert!(ctx.is_none(), "multiple calls to MullvadDaemon.initialize"); let env = JnixEnv::from(env); + let files_dir = pathbuf_from_java(&env, files_directory); + start_logging(&files_dir) + .map_err(Error::InitializeLogging) + .unwrap(); + version::log_version(); + log::info!("Pre-loading classes!"); LOAD_CLASSES.call_once(|| env.preload_classes(classes::CLASSES.iter().cloned())); + log::info!("Done loading classes"); let rpc_socket = pathbuf_from_java(&env, rpc_socket_path); - let files_dir = pathbuf_from_java(&env, files_directory); let cache_dir = pathbuf_from_java(&env, cache_directory); let android_context = ok_or_throw!(&env, create_android_context(&env, vpn_service)); + log::info!("Created Android Context"); let api_endpoint = api::api_endpoint_from_java(&env, api_endpoint); + log::info!("Starting daemon"); let daemon = ok_or_throw!( &env, start( @@ -134,11 +141,8 @@ fn start( rpc_socket: PathBuf, files_dir: PathBuf, cache_dir: PathBuf, - api_endpoint: Option, + api_endpoint: Option, ) -> Result { - start_logging(&files_dir).map_err(Error::InitializeLogging)?; - version::log_version(); - #[cfg(not(feature = "api-override"))] if api_endpoint.is_some() { log::warn!("api_endpoint will be ignored since 'api-override' is not enabled"); diff --git a/mullvad-jni/src/talpid_vpn_service.rs b/mullvad-jni/src/talpid_vpn_service.rs deleted file mode 100644 index ea6928538a86..000000000000 --- a/mullvad-jni/src/talpid_vpn_service.rs +++ /dev/null @@ -1,181 +0,0 @@ -use ipnetwork::IpNetwork; -use jnix::jni::{ - objects::JObject, - sys::{jboolean, jint, JNI_FALSE}, - JNIEnv, -}; -use nix::sys::{ - select::{pselect, FdSet}, - time::{TimeSpec, TimeValLike}, -}; -use rand::{thread_rng, Rng}; -use std::{ - io, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, - os::unix::io::RawFd, - time::{Duration, Instant}, -}; -use talpid_types::ErrorExt; - -#[derive(Debug, thiserror::Error)] -enum Error { - #[error("Failed to verify the tunnel device")] - VerifyTunDevice(#[from] SendRandomDataError), - - #[error("Failed to select() on tunnel device")] - Select(#[from] nix::Error), - - #[error("Timed out while waiting for tunnel device to receive data")] - TunnelDeviceTimeout, -} - -#[no_mangle] -#[allow(non_snake_case)] -pub extern "system" fn Java_net_mullvad_talpid_TalpidVpnService_waitForTunnelUp( - _: JNIEnv<'_>, - _this: JObject<'_>, - tunFd: jint, - isIpv6Enabled: jboolean, -) { - let tun_fd = tunFd as RawFd; - let is_ipv6_enabled = isIpv6Enabled != JNI_FALSE; - - if let Err(error) = wait_for_tunnel_up(tun_fd, is_ipv6_enabled) { - log::error!( - "{}", - error.display_chain_with_msg("Failed to wait for tunnel device to be usable") - ); - } -} - -fn wait_for_tunnel_up(tun_fd: RawFd, is_ipv6_enabled: bool) -> Result<(), Error> { - let mut fd_set = FdSet::new(); - fd_set.insert(tun_fd); - let timeout = TimeSpec::microseconds(300); - const TIMEOUT: Duration = Duration::from_secs(60); - let start = Instant::now(); - while start.elapsed() < TIMEOUT { - // if tunnel device is ready to be read from, traffic is being routed through it - if pselect(None, Some(&mut fd_set), None, None, Some(&timeout), None)? > 0 { - return Ok(()); - } - // have to add tun_fd back into the bitset - fd_set.insert(tun_fd); - try_sending_random_udp(is_ipv6_enabled)?; - } - - Err(Error::TunnelDeviceTimeout) -} - -#[derive(Debug, thiserror::Error)] -enum SendRandomDataError { - #[error("Failed to bind an UDP socket")] - BindUdpSocket(#[source] io::Error), - - #[error("Failed to send random data through UDP socket")] - SendToUdpSocket(#[source] io::Error), -} - -fn try_sending_random_udp(is_ipv6_enabled: bool) -> Result<(), SendRandomDataError> { - let mut tried_ipv6 = false; - const TIMEOUT: Duration = Duration::from_millis(300); - let start = Instant::now(); - - while start.elapsed() < TIMEOUT { - // TODO: if we are to allow LAN on Android by changing the routes that are stuffed in - // TunConfig, then this should be revisited to be fair between IPv4 and IPv6 - let should_generate_ipv4 = !is_ipv6_enabled || tried_ipv6 || thread_rng().gen(); - let (bound_addr, random_public_addr) = random_socket_addrs(should_generate_ipv4); - - tried_ipv6 |= random_public_addr.ip().is_ipv6(); - - let socket = UdpSocket::bind(bound_addr).map_err(SendRandomDataError::BindUdpSocket)?; - match socket.send_to(&random_data(), random_public_addr) { - Ok(_) => return Ok(()), - // Always retry on IPv6 errors - Err(_) if random_public_addr.ip().is_ipv6() => continue, - Err(_err) if matches!(_err.raw_os_error(), Some(22) | Some(101)) => { - // Error code 101 - specified network is unreachable - // Error code 22 - specified address is not usable - continue; - } - Err(err) => return Err(SendRandomDataError::SendToUdpSocket(err)), - } - } - Ok(()) -} - -fn random_data() -> Vec { - let mut buf = vec![0u8; thread_rng().gen_range(17..214)]; - thread_rng().fill(buf.as_mut_slice()); - buf -} - -/// Returns a random local and public destination socket address. -/// If `ipv4` is true, then IPv4 addresses will be returned. Otherwise, IPv6 addresses will be -/// returned. -fn random_socket_addrs(ipv4: bool) -> (SocketAddr, SocketAddr) { - loop { - let rand_port = thread_rng().gen(); - let (local_addr, rand_dest_addr) = if ipv4 { - let mut ipv4_bytes = [0u8; 4]; - thread_rng().fill(&mut ipv4_bytes); - ( - SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0), - SocketAddr::new(IpAddr::from(ipv4_bytes), rand_port), - ) - } else { - let mut ipv6_bytes = [0u8; 16]; - thread_rng().fill(&mut ipv6_bytes); - ( - SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), - SocketAddr::new(IpAddr::from(ipv6_bytes), rand_port), - ) - }; - - // TODO: once https://github.com/rust-lang/rust/issues/27709 is resolved, please use - // `is_global()` to check if a new address should be attempted. - if !is_public_ip(rand_dest_addr.ip()) { - continue; - } - - return (local_addr, rand_dest_addr); - } -} - -fn is_public_ip(addr: IpAddr) -> bool { - match addr { - IpAddr::V4(ipv4) => { - // 0.x.x.x is not a publicly routable address - if ipv4.octets()[0] == 0u8 { - return false; - } - } - IpAddr::V6(ipv6) => { - if ipv6.segments()[0] == 0u16 { - return false; - } - } - } - // A non-exhaustive list of non-public subnets - let publicly_unroutable_subnets: Vec = vec![ - // IPv4 local networks - "10.0.0.0/8".parse().unwrap(), - "172.16.0.0/12".parse().unwrap(), - "192.168.0.0/16".parse().unwrap(), - // IPv4 non-forwardable network - "169.254.0.0/16".parse().unwrap(), - "192.0.0.0/8".parse().unwrap(), - // Documentation networks - "192.0.2.0/24".parse().unwrap(), - "198.51.100.0/24".parse().unwrap(), - "203.0.113.0/24".parse().unwrap(), - // IPv6 publicly unroutable networks - "fc00::/7".parse().unwrap(), - "fe80::/10".parse().unwrap(), - ]; - - !publicly_unroutable_subnets - .iter() - .any(|net| net.contains(addr)) -} diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 7c7637cd20c2..9060787536db 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -138,6 +138,7 @@ impl ConnectingState { &shared_values.route_manager, retry_attempt, ); + let params = connecting_state.tunnel_parameters.clone(); ( Box::new(connecting_state), diff --git a/talpid-routing/Cargo.toml b/talpid-routing/Cargo.toml index 14f30b83338a..b4d3e2a747d9 100644 --- a/talpid-routing/Cargo.toml +++ b/talpid-routing/Cargo.toml @@ -16,10 +16,11 @@ futures = { workspace = true } ipnetwork = { workspace = true } log = { workspace = true } tokio = { workspace = true, features = ["process", "rt-multi-thread", "net", "io-util", "time"] } - -[target.'cfg(not(target_os="android"))'.dependencies] talpid-types = { path = "../talpid-types" } +[target.'cfg(target_os = "android")'.dependencies] +jnix = { version = "0.5.2", features = ["derive"] } + [target.'cfg(target_os = "linux")'.dependencies] libc = "0.2" rtnetlink = "0.11" diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs index b80f96ccdc56..89effbdd466b 100644 --- a/talpid-routing/src/lib.rs +++ b/talpid-routing/src/lib.rs @@ -24,7 +24,10 @@ mod imp; use netlink_packet_route::rtnl::constants::RT_TABLE_MAIN; #[cfg(target_os = "macos")] -pub use imp::{imp::RouteError, DefaultRouteEvent, PlatformError}; +pub use imp::{ + imp::{DefaultRouteEvent, RouteError}, + PlatformError, +}; pub use imp::{Error, RouteManagerHandle}; @@ -70,6 +73,7 @@ pub struct Gateway { } /// A network route with a specific network node, destination and an optional metric. +#[cfg(not(target_os = "android"))] #[derive(Debug, Hash, Eq, PartialEq, Clone)] pub struct Route { node: Node, @@ -81,8 +85,14 @@ pub struct Route { mtu: Option, } +/// A network route with a specific network node, destination and an optional metric. +#[cfg(target_os = "android")] +#[derive(Debug, Hash, Eq, PartialEq, Clone)] +pub struct Route(IpNetwork); + impl Route { /// Construct a new Route + #[cfg(not(target_os = "android"))] pub fn new(node: Node, prefix: IpNetwork) -> Self { Self { node, @@ -95,6 +105,12 @@ impl Route { } } + /// Construct a new Route + #[cfg(target_os = "android")] + pub fn new(prefix: IpNetwork) -> Self { + Self(prefix) + } + #[cfg(target_os = "linux")] fn table(mut self, new_id: u32) -> Self { self.table_id = new_id; @@ -102,11 +118,13 @@ impl Route { } /// Returns the network node of the route. + #[cfg(target_os = "linux")] pub fn get_node(&self) -> &Node { &self.node } } +#[cfg(target_os = "linux")] impl fmt::Display for Route { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{} via {}", self.prefix, self.node)?; @@ -123,9 +141,22 @@ impl fmt::Display for Route { } } +#[cfg(target_os = "android")] +impl From<&talpid_types::android::RouteInfo> for Route { + fn from(route_info: &talpid_types::android::RouteInfo) -> Self { + let network = IpNetwork::new( + route_info.destination.address, + route_info.destination.prefix_length as u8, + ) + .unwrap(); + Self::new(network) + } +} + /// A network route that should be applied by the route manager. /// It can either be routed through a specific network node or it can be routed through the current /// default route. +#[cfg(not(target_os = "android"))] #[derive(Debug, Hash, Eq, PartialEq, Clone)] pub struct RequiredRoute { /// Route's prefix @@ -139,6 +170,7 @@ pub struct RequiredRoute { mtu: Option, } +#[cfg(not(target_os = "android"))] impl RequiredRoute { /// Constructs a new required route. pub fn new(prefix: IpNetwork, node: impl Into) -> Self { diff --git a/talpid-routing/src/unix/android.rs b/talpid-routing/src/unix/android.rs index 8abb23859bfd..be9f8b7d6ade 100644 --- a/talpid-routing/src/unix/android.rs +++ b/talpid-routing/src/unix/android.rs @@ -1,37 +1,232 @@ -use crate::imp::RouteManagerCommand; -use futures::{channel::mpsc, stream::StreamExt}; +use std::collections::HashSet; +use std::ops::{ControlFlow, Not}; +use std::sync::Mutex; + +use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use futures::channel::oneshot; +use futures::future::FutureExt; +use futures::select_biased; +use futures::stream::StreamExt; +use jnix::jni::objects::JValue; +use jnix::jni::{objects::JObject, JNIEnv}; +use jnix::{FromJava, JnixEnv}; + +use talpid_types::android::{AndroidContext, NetworkState}; + +use crate::{imp::RouteManagerCommand, Route}; /// Stub error type for routing errors on Android. +/// Errors that occur while setting up VpnService tunnel. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Timed out when waiting for network routes. + #[error("Timed out when waiting for network routes")] + RoutesTimedOut, +} + +/// Internal errors that may only happen during the initial poll for [NetworkState]. #[derive(Debug, thiserror::Error)] -#[error("Failed to send shutdown result")] -pub struct Error; +enum JvmError { + #[error("Failed to attach Java VM to tunnel thread")] + AttachJvmToThread(#[source] jnix::jni::errors::Error), + #[error("Failed to call Java method {0}")] + CallMethod(&'static str, #[source] jnix::jni::errors::Error), + #[error("Failed to create global reference to Java object")] + CreateGlobalRef(#[source] jnix::jni::errors::Error), + #[error("Received an invalid result from {0}.{1}: {2}")] + InvalidMethodResult(&'static str, &'static str, String), +} + +/// The sender used by [Java_net_mullvad_talpid_ConnectivityListener_notifyDefaultNetworkChange] +/// to notify the route manager of changes to the network. +static ROUTE_UPDATES_TX: Mutex>>> = Mutex::new(None); + +/// Android route manager actor. +#[derive(Debug)] +pub struct RouteManagerImpl { + /// The receiving channel for updates on changes to the network. + network_state_updates: UnboundedReceiver>, -/// Stub route manager for Android -pub struct RouteManagerImpl {} + /// Cached [NetworkState]. If no update events have been received yet, this value will be [None]. + last_state: Option, + + /// Clients waiting on response to [RouteManagerCommand::WaitForRoutes]. + waiting_for_routes: Vec>, +} impl RouteManagerImpl { #[allow(clippy::unused_async)] - pub async fn new() -> Result { - Ok(RouteManagerImpl {}) + pub async fn new(android_context: AndroidContext) -> Result { + // Create a channel between the kotlin client and route manager + let (tx, rx) = futures::channel::mpsc::unbounded(); + + *ROUTE_UPDATES_TX.lock().unwrap() = Some(tx); + + // Try to poll for the current network state at startup. + // This will most likely be null, but it covers the edge case where a NetworkState + // update has been emitted before we anyone starts to listen for route updates some + // time in the future (when connecting). + let last_state = match current_network_state(android_context) { + Ok(initial_state) => initial_state, + Err(err) => { + log::error!("Failed while polling for initial NetworkState"); + log::error!("{err}"); + None + } + }; + + let route_manager = RouteManagerImpl { + network_state_updates: rx, + last_state, + waiting_for_routes: Default::default(), + }; + + Ok(route_manager) } pub(crate) async fn run( - self, + mut self, manage_rx: mpsc::UnboundedReceiver, ) -> Result<(), Error> { let mut manage_rx = manage_rx.fuse(); - while let Some(command) = manage_rx.next().await { - match command { - RouteManagerCommand::Shutdown(tx) => { - tx.send(()).map_err(|()| Error)?; - break; + + loop { + select_biased! { + command = manage_rx.next().fuse() => { + let Some(command) = command else { break }; + if self.handle_command(command).is_break() { + break; + } } - RouteManagerCommand::AddRoutes(_routes, tx) => { - let _ = tx.send(Ok(())); + + network_state_update = self.network_state_updates.next().fuse() => { + // None means that the sender was dropped + let Some(network_state) = network_state_update else { break }; + // update the last known NetworkState + self.last_state = network_state; + + if has_routes(self.last_state.as_ref()) { + // notify waiting clients that routes exist + for client in self.waiting_for_routes.drain(..) { + let _ = client.send(()); + } + } } - RouteManagerCommand::ClearRoutes => (), } } + + log::debug!("RouteManager exited"); + Ok(()) } + + fn handle_command(&mut self, command: RouteManagerCommand) -> ControlFlow<()> { + match command { + RouteManagerCommand::Shutdown(tx) => { + let _ = tx.send(()); + return ControlFlow::Break(()); + } + RouteManagerCommand::WaitForRoutes(response_tx) => { + // check if routes have already been configured on the Android system. + // otherwise, register a listener for network state changes. + // routes may come in at any moment in the future. + if has_routes(self.last_state.as_ref()) { + let _ = response_tx.send(()); + } else { + self.waiting_for_routes.push(response_tx); + } + } + } + + ControlFlow::Continue(()) + } +} + +/// Check whether the [NetworkState] contains any routes. +/// +/// Since we are the ones telling Android what routes to set, we make the assumption that: +/// If any routes exist whatsoever, they are the the routes we specified. +fn has_routes(state: Option<&NetworkState>) -> bool { + let Some(network_state) = state else { + return false; + }; + configured_routes(network_state).is_empty().not() +} + +fn configured_routes(state: &NetworkState) -> HashSet { + match &state.routes { + None => Default::default(), + Some(route_info) => route_info.iter().map(Route::from).collect(), + } +} + +/// Entry point for Android Java code to notify the current default network state. +#[no_mangle] +#[allow(non_snake_case)] +pub extern "system" fn Java_net_mullvad_talpid_ConnectivityListener_notifyDefaultNetworkChange( + env: JNIEnv<'_>, + _: JObject<'_>, + network_state: JObject<'_>, +) { + let env = JnixEnv::from(env); + + let network_state: Option = FromJava::from_java(&env, network_state); + + let Some(tx) = &*ROUTE_UPDATES_TX.lock().unwrap() else { + // No sender has been registered + log::error!("Received routes notification wíth no channel"); + return; + }; + + log::trace!("Received network state update {:#?}", network_state); + + if tx.unbounded_send(network_state).is_err() { + log::warn!("Failed to send offline change event"); + } +} + +/// Return the current NetworkState according to Android +fn current_network_state( + android_context: AndroidContext, +) -> Result, JvmError> { + let env = JnixEnv::from( + android_context + .jvm + .attach_current_thread_as_daemon() + .map_err(JvmError::AttachJvmToThread)?, + ); + + let result = env + .call_method( + android_context.vpn_service.as_obj(), + "getConnectivityListener", + "()Lnet/mullvad/talpid/ConnectivityListener;", + &[], + ) + .map_err(|cause| JvmError::CallMethod("getConnectivityListener", cause))?; + + let connectivity_listener = match result { + JValue::Object(object) => env + .new_global_ref(object) + .map_err(JvmError::CreateGlobalRef)?, + value => { + return Err(JvmError::InvalidMethodResult( + "MullvadVpnService", + "getConnectivityListener", + format!("{:?}", value), + )) + } + }; + + let network_state = env + .call_method( + connectivity_listener.as_obj(), + "getCurrentDefaultNetworkState", + "()Lnet/mullvad/talpid/model/NetworkState;", + &[], + ) + .map_err(|cause| JvmError::CallMethod("getCurrentDefaultNetworkState", cause))?; + + let network_state: Option = FromJava::from_java(&env, network_state); + Ok(network_state) } diff --git a/talpid-routing/src/unix/linux.rs b/talpid-routing/src/unix/linux.rs index 92b4513301d3..a43f0690bc3b 100644 --- a/talpid-routing/src/unix/linux.rs +++ b/talpid-routing/src/unix/linux.rs @@ -86,6 +86,7 @@ pub type Result = std::result::Result; /// Errors that can happen in the Linux routing integration #[derive(thiserror::Error, Debug)] +#[allow(missing_docs)] pub enum Error { #[error("Failed to open a netlink connection")] Connect(#[source] io::Error), diff --git a/talpid-routing/src/unix/macos/mod.rs b/talpid-routing/src/unix/macos/mod.rs index 85a020ba797f..df89767e38fa 100644 --- a/talpid-routing/src/unix/macos/mod.rs +++ b/talpid-routing/src/unix/macos/mod.rs @@ -16,9 +16,10 @@ use std::{ use talpid_types::ErrorExt; use watch::RoutingTable; -use super::{DefaultRouteEvent, RouteManagerCommand}; +use super::RouteManagerCommand; use data::{Destination, RouteDestination, RouteMessage, RouteSocketMessage}; +pub use super::DefaultRouteEvent; pub use interface::DefaultRoute; mod data; diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 34d2570137c6..042360d52027 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -1,18 +1,23 @@ -#[cfg(target_os = "linux")] -use crate::Route; #[cfg(target_os = "macos")] pub use crate::{imp::imp::DefaultRoute, Gateway}; +#[cfg(any(target_os = "linux", target_os = "macos"))] use super::RequiredRoute; +#[cfg(target_os = "linux")] +use super::Route; use futures::channel::{ mpsc::{self, UnboundedSender}, oneshot, }; -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; +#[cfg(target_os = "android")] +use talpid_types::android::AndroidContext; #[cfg(any(target_os = "linux", target_os = "macos"))] use futures::stream::Stream; +#[cfg(any(target_os = "linux", target_os = "macos"))] +use std::collections::HashSet; #[cfg(target_os = "linux")] use std::net::IpAddr; @@ -32,6 +37,7 @@ mod imp; #[path = "android.rs"] mod imp; +#[cfg(any(target_os = "macos", target_os = "linux"))] pub use imp::Error as PlatformError; /// Errors that can be encountered whilst interacting with a [RouteManagerHandle]. @@ -97,11 +103,7 @@ pub(crate) enum RouteManagerCommand { #[cfg(target_os = "android")] #[derive(Debug)] pub(crate) enum RouteManagerCommand { - AddRoutes( - HashSet, - oneshot::Sender>, - ), - ClearRoutes, + WaitForRoutes(oneshot::Sender<()>), Shutdown(oneshot::Sender<()>), } @@ -165,6 +167,7 @@ impl RouteManagerHandle { pub async fn spawn( #[cfg(target_os = "linux")] fwmark: u32, #[cfg(target_os = "linux")] table_id: u32, + #[cfg(target_os = "android")] android_context: AndroidContext, ) -> Result { let (manage_tx, manage_rx) = mpsc::unbounded(); let manage_tx = Arc::new(manage_tx); @@ -175,6 +178,8 @@ impl RouteManagerHandle { table_id, #[cfg(target_os = "macos")] Arc::downgrade(&manage_tx), + #[cfg(target_os = "android")] + android_context, ) .await?; tokio::spawn(manager.run(manage_rx)); @@ -192,6 +197,7 @@ impl RouteManagerHandle { } /// Applies the given routes until they are cleared + #[cfg(not(target_os = "android"))] pub async fn add_routes(&self, routes: HashSet) -> Result<(), Error> { let (result_tx, result_rx) = oneshot::channel(); self.tx @@ -204,13 +210,43 @@ impl RouteManagerHandle { .map_err(Error::PlatformError) } + /// Wait for routes to come up. + /// + /// This function is guaranteed to *not* wait for longer than 2 seconds. + /// Please, see the implementation of this function for further details. + #[cfg(target_os = "android")] + pub async fn wait_for_routes(&self) -> Result<(), Error> { + use std::time::Duration; + use tokio::time::timeout; + /// Maximum time to wait for routes to come up. The expected mean time is low (~200 ms), but + /// we add some additional margin to give some slack to slower hardware primarily. + const WAIT_FOR_ROUTES_TIMEOUT: Duration = Duration::from_secs(2); + + let (result_tx, result_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::WaitForRoutes(result_tx)) + .map_err(|_| Error::RouteManagerDown)?; + + timeout(WAIT_FOR_ROUTES_TIMEOUT, result_rx) + .await + .map_err(|_error| Error::PlatformError(imp::Error::RoutesTimedOut))? + .map_err(|_| Error::ManagerChannelDown) + } + /// Removes all routes previously applied in [`RouteManagerHandle::add_routes`]. + #[cfg(not(target_os = "android"))] pub fn clear_routes(&self) -> Result<(), Error> { self.tx .unbounded_send(RouteManagerCommand::ClearRoutes) .map_err(|_| Error::RouteManagerDown) } + /// (Android) This is a noop since we don't directly control the routes on Android. + #[cfg(target_os = "android")] + pub fn clear_routes(&self) -> Result<(), Error> { + Ok(()) + } + /// Listen for non-tunnel default route changes. #[cfg(target_os = "macos")] pub async fn default_route_listener( diff --git a/talpid-tunnel/src/tun_provider/android/mod.rs b/talpid-tunnel/src/tun_provider/android/mod.rs index 3d356e50d328..f285b4a64ca1 100644 --- a/talpid-tunnel/src/tun_provider/android/mod.rs +++ b/talpid-tunnel/src/tun_provider/android/mod.rs @@ -46,6 +46,9 @@ pub enum Error { #[error("Failed to create tunnel device")] TunnelDeviceError, + #[error("Routes timed out")] + RoutesTimedOut, + #[error("Profile for VPN has not been setup")] NotPrepared, @@ -381,7 +384,7 @@ impl AsRawFd for VpnServiceTun { enum CreateTunResult { Success { tun_fd: i32 }, InvalidDnsServers { addresses: Vec }, - TunnelDeviceError, + EstablishError, OtherLegacyAlwaysOnVpn, OtherAlwaysOnApp { app_name: String }, NotPrepared, @@ -394,7 +397,7 @@ impl From for Result { CreateTunResult::InvalidDnsServers { addresses } => { Err(Error::InvalidDnsServers(addresses)) } - CreateTunResult::TunnelDeviceError => Err(Error::TunnelDeviceError), + CreateTunResult::EstablishError => Err(Error::TunnelDeviceError), CreateTunResult::OtherLegacyAlwaysOnVpn => Err(Error::OtherLegacyAlwaysOnVpn), CreateTunResult::OtherAlwaysOnApp { app_name } => { Err(Error::OtherAlwaysOnApp { app_name }) diff --git a/talpid-types/src/android/mod.rs b/talpid-types/src/android/mod.rs index 4169216f3b5e..c5484aaafdfd 100644 --- a/talpid-types/src/android/mod.rs +++ b/talpid-types/src/android/mod.rs @@ -1,8 +1,61 @@ +use ipnetwork::{IpNetwork, IpNetworkError, Ipv4Network, Ipv6Network}; use jnix::jni::{objects::GlobalRef, JavaVM}; +use jnix::{FromJava, IntoJava}; +use std::net::IpAddr; use std::sync::Arc; +/// What Java calls an [IpAddr] +pub type InetAddress = IpAddr; + #[derive(Clone)] pub struct AndroidContext { pub jvm: Arc, pub vpn_service: GlobalRef, } + +/// A Java-compatible variant of [IpNetwork] +#[derive(Clone, Debug, Eq, PartialEq, Hash, IntoJava, FromJava)] +#[jnix(package = "net.mullvad.talpid.model")] +pub struct InetNetwork { + pub address: IpAddr, + pub prefix_length: i16, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash, IntoJava, FromJava)] +#[jnix(package = "net.mullvad.talpid.model")] +pub struct RouteInfo { + pub destination: InetNetwork, + pub gateway: Option, + pub interface: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash, IntoJava, FromJava)] +#[jnix(package = "net.mullvad.talpid.model")] +pub struct NetworkState { + pub network_handle: i64, + pub routes: Option>, + pub dns_servers: Option>, +} + +impl From for InetNetwork { + fn from(ip_network: IpNetwork) -> Self { + InetNetwork { + address: ip_network.ip(), + prefix_length: ip_network.prefix() as i16, + } + } +} + +impl TryFrom for IpNetwork { + type Error = IpNetworkError; + fn try_from(inet_network: InetNetwork) -> Result { + Ok(match inet_network.address { + IpAddr::V4(addr) => { + IpNetwork::V4(Ipv4Network::new(addr, inet_network.prefix_length as u8)?) + } + IpAddr::V6(addr) => { + IpNetwork::V6(Ipv6Network::new(addr, inet_network.prefix_length as u8)?) + } + }) + } +} diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index af8e2da79e82..fe1a848e9a74 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -403,7 +403,6 @@ impl WireguardMonitor { let desired_mtu = get_desired_mtu(params); let mut config = Config::from_parameters(params, desired_mtu).map_err(Error::WireguardConfigError)?; - let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel(); // Start obfuscation server and patch the WireGuard config to point the endpoint to it. let obfuscator = args @@ -466,6 +465,13 @@ impl WireguardMonitor { .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; + // Wait for routes to come up + args.route_manager + .wait_for_routes() + .await + .map_err(Error::SetupRoutingError) + .map_err(CloseMsg::SetupError)?; + if should_negotiate_ephemeral_peer { let ephemeral_obfs_sender = close_obfs_sender.clone();