Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Disconnect/re-connect from notification when both VPN and AppTP are enabled #5154

Merged
merged 3 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ interface Vpn {
suspend fun start()

/**
* Disable the device VPN by stopping the VPN service
* Pauses the VPN tunnel.
* All features that were registered to use the VPN tunnel (eg. AppTP, NetP) continue to be registered and so a subsequent
* [start] call will re-enable them all
*/
suspend fun pause()

/**
* Stops the VPN tunnel AND all features registered to use the VPN tunnel (eg. AppTP, NetP). A subsequent call to [start]
* will not re-start the VPN because no feature would be registered.
*/
suspend fun stop()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,33 @@ import android.content.SharedPreferences
import androidx.core.content.edit
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.data.store.api.SharedPreferencesProvider
import com.duckduckgo.di.scopes.AppScope
import com.duckduckgo.mobile.android.vpn.service.TrackerBlockingVpnService
import com.squareup.anvil.annotations.ContributesBinding
import dagger.SingleInstanceIn
import java.util.UUID
import javax.inject.Inject
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.withContext
import logcat.logcat

private const val PREFS_FILENAME = "com.duckduckgo.mobile.android.vpn.feature.registry.v1"
private const val IS_INITIALIZED = "IS_INITIALIZED"

internal class VpnFeaturesRegistryImpl(
@ContributesBinding(
scope = AppScope::class,
boundType = VpnFeaturesRegistry::class,
)
@ContributesBinding(
scope = AppScope::class,
boundType = Vpn::class,
)
@SingleInstanceIn(AppScope::class)
class VpnFeaturesRegistryImpl @Inject constructor(
private val vpnServiceWrapper: VpnServiceWrapper,
private val sharedPreferencesProvider: SharedPreferencesProvider,
private val dispatcherProvider: DispatcherProvider,
) : VpnFeaturesRegistry {
) : VpnFeaturesRegistry, Vpn {

private val mutex = Mutex()

Expand Down Expand Up @@ -103,6 +116,34 @@ internal class VpnFeaturesRegistryImpl(
private fun registeredFeatures(): Map<String, Any?> {
return preferences.all.filter { it.key != IS_INITIALIZED }
}

override suspend fun start() = withContext(dispatcherProvider.io()) {
vpnServiceWrapper.startService()
}

override suspend fun pause() {
vpnServiceWrapper.stopService()
}

override suspend fun stop() {
try {
mutex.lock()
// unregister all features
getRegisteredFeatures().onEach {
preferences.edit(commit = true) {
remove(it.featureName)
}
}
// stop VPN
vpnServiceWrapper.stopService()
} finally {
mutex.unlock()
}
}

override suspend fun snooze(triggerAtMillis: Long) {
vpnServiceWrapper.snoozeService(triggerAtMillis)
}
}

/**
Expand All @@ -111,10 +152,9 @@ internal class VpnFeaturesRegistryImpl(
*
* The class is marked as open to be able to mock it in tests.
*/
internal open class VpnServiceWrapper(
open class VpnServiceWrapper @Inject constructor(
private val context: Context,
private val dispatcherProvider: DispatcherProvider,
) : Vpn {
) {
open fun restartVpnService(forceRestart: Boolean) {
TrackerBlockingVpnService.restartVpnService(context, forceRestart = forceRestart)
}
Expand All @@ -127,19 +167,11 @@ internal open class VpnServiceWrapper(
TrackerBlockingVpnService.startService(context)
}

open fun isServiceRunning(): Boolean {
return TrackerBlockingVpnService.isServiceRunning(context)
}

override suspend fun start() = withContext(dispatcherProvider.io()) {
startService()
}

override suspend fun stop() = withContext(dispatcherProvider.io()) {
stopService()
open fun snoozeService(triggerAtMillis: Long) {
TrackerBlockingVpnService.snoozeService(context, triggerAtMillis)
}

override suspend fun snooze(triggerAtMillis: Long) {
TrackerBlockingVpnService.snoozeService(context, triggerAtMillis)
open fun isServiceRunning(): Boolean {
return TrackerBlockingVpnService.isServiceRunning(context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ import android.content.res.Resources
import android.net.ConnectivityManager
import androidx.room.Room
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.data.store.api.SharedPreferencesProvider
import com.duckduckgo.di.scopes.AppScope
import com.duckduckgo.mobile.android.vpn.Vpn
import com.duckduckgo.mobile.android.vpn.VpnFeaturesRegistry
import com.duckduckgo.mobile.android.vpn.VpnFeaturesRegistryImpl
import com.duckduckgo.mobile.android.vpn.VpnServiceWrapper
import com.duckduckgo.mobile.android.vpn.stats.AppTrackerBlockingStatsRepository
import com.duckduckgo.mobile.android.vpn.stats.RealAppTrackerBlockingStatsRepository
import com.duckduckgo.mobile.android.vpn.store.*
Expand Down Expand Up @@ -88,25 +83,6 @@ object VpnAppModule {
return context.resources
}

@Provides
@SingleInstanceIn(AppScope::class)
fun provideVpnFeaturesRegistry(
context: Context,
sharedPreferencesProvider: SharedPreferencesProvider,
dispatcherProvider: DispatcherProvider,
): VpnFeaturesRegistry {
return VpnFeaturesRegistryImpl(VpnServiceWrapper(context, dispatcherProvider), sharedPreferencesProvider, dispatcherProvider)
}

@Provides
@SingleInstanceIn(AppScope::class)
fun provideVpnServiceWrapper(
context: Context,
dispatcherProvider: DispatcherProvider,
): Vpn {
return VpnServiceWrapper(context, dispatcherProvider)
}

@Provides
fun provideAppTrackerBlockingStatsRepository(
vpnDatabase: VpnDatabase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import android.os.SystemClock
import com.duckduckgo.anvil.annotations.InjectWith
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.di.scopes.ReceiverScope
import com.duckduckgo.mobile.android.app.tracking.AppTrackingProtection
import com.duckduckgo.mobile.android.vpn.Vpn
import com.duckduckgo.mobile.android.vpn.pixels.DeviceShieldPixels
import com.duckduckgo.networkprotection.api.NetworkProtectionState
import dagger.android.AndroidInjection
import java.util.concurrent.TimeUnit
import javax.inject.Inject
Expand All @@ -38,6 +40,10 @@ import logcat.logcat
class VpnActionReceiver : BroadcastReceiver() {
@Inject lateinit var vpn: Vpn

@Inject lateinit var networkProtectionState: NetworkProtectionState
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIt: Can be removed as they are not used anymore in this class right?


@Inject lateinit var appTrackingProtection: AppTrackingProtection

@Inject lateinit var dispatcherProvider: DispatcherProvider

@Inject lateinit var context: Context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.squareup.anvil.annotations.ContributesMultibinding
import dagger.SingleInstanceIn
import javax.inject.Inject
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.runBlocking

Expand Down Expand Up @@ -77,6 +78,7 @@ class AppTPAndNetPEnabledNotificationContentPlugin @Inject constructor(
}

return repository.getVpnTrackers({ dateOfLastHour() })
.filter { isActive() } // make sure we only emit when this plugin is active
.map { trackersBlocked ->
val trackingApps = trackersBlocked.trackingApps()
val location = networkProtectionState.serverLocation()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class AppTpEnabledNotificationContentPlugin @Inject constructor(
}

return repository.getVpnTrackers({ dateOfLastHour() })
.filter { isActive() } // make sure we only emit when this plugin is active
.map { trackersBlocked ->
val trackingApps = trackersBlocked.trackingApps()
val isEnabled = appTrackingProtection.isEnabled()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import com.duckduckgo.common.test.CoroutineTestRule
import com.duckduckgo.common.test.api.InMemorySharedPreferences
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.data.store.api.SharedPreferencesProvider
import kotlinx.coroutines.test.runTest
import org.junit.Assert.*
Expand All @@ -44,7 +43,7 @@ class VpnFeaturesRegistryImplTest {
@Before
fun setup() {
val prefs = InMemorySharedPreferences()
vpnServiceWrapper = TestVpnServiceWrapper(coroutineTestRule.testDispatcherProvider)
vpnServiceWrapper = TestVpnServiceWrapper()

whenever(
sharedPreferencesProvider.getSharedPreferences(eq("com.duckduckgo.mobile.android.vpn.feature.registry.v1"), eq(true), eq(false)),
Expand Down Expand Up @@ -163,9 +162,7 @@ class VpnFeaturesRegistryImplTest {
BAR("BAR"),
}

private class TestVpnServiceWrapper constructor(
dispatcher: DispatcherProvider,
) : VpnServiceWrapper(InstrumentationRegistry.getInstrumentation().context, dispatcher) {
private class TestVpnServiceWrapper constructor() : VpnServiceWrapper(InstrumentationRegistry.getInstrumentation().context) {
private var isRunning = false
var restartCount = 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentThenReturnsCorrectInitialUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn("Stockholm, SE")
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
val item = awaitItem()
Expand All @@ -151,6 +153,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentOneCompanyThenReturnsCorrectUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn("Stockholm, SE")
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
val trackers = listOf(aTrackerAndCompany())
Expand All @@ -169,6 +173,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentMultipleDifferentAppsThenReturnsCorrectUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn("Stockholm, SE")
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
val trackers = listOf(
Expand All @@ -193,6 +199,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentMultipleDifferentAppsNoLocationThenReturnsCorrectUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn(null)
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
val trackers = listOf(
Expand All @@ -217,6 +225,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentTrackersWithoutEntityThenReturnsCorrectUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn("Stockholm, SE")
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
appTrackerBlockingStatsRepository.insert(listOf(aTrackerAndCompany(), aTrackerAndCompany()))
Expand All @@ -232,6 +242,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentTrackersWithoutEntityNoLocationThenReturnsCorrectUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn(null)
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
appTrackerBlockingStatsRepository.insert(listOf(aTrackerAndCompany(), aTrackerAndCompany()))
Expand All @@ -247,6 +259,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentMultipleSameThenReturnsCorrectUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn("Stockholm, SE")
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
appTrackerBlockingStatsRepository.insert(listOf(aTrackerAndCompany(), aTrackerAndCompany()))
Expand All @@ -265,6 +279,8 @@ class AppTPAndNetPEnabledNotificationContentPluginTest {
@Test
fun getUpdateContentMultipleSameNoLocationThenReturnsCorrectUpdatedNotificationContent() = runTest {
whenever(networkProtectionState.serverLocation()).thenReturn(null)
whenever(appTrackingProtection.isEnabled()).thenReturn(true)
whenever(networkProtectionState.isEnabled()).thenReturn(true)

plugin.getUpdatedContent().test {
appTrackerBlockingStatsRepository.insert(listOf(aTrackerAndCompany(), aTrackerAndCompany()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,12 @@ class AppTpEnabledNotificationContentPluginTest {
}

@Test
fun getUpdateContentAppTpNotEnabledThenReturnsCorrectInitialUpdatedNotificationContent() = runTest {
fun getUpdateContentAppTpNotEnabledThenReturnsNoContent() = runTest {
whenever(appTrackingProtection.isEnabled()).thenReturn(false)
whenever(networkProtectionState.isEnabled()).thenReturn(false)

plugin.getUpdatedContent().test {
val item = awaitItem()

item.assertTextEquals("")

cancelAndConsumeRemainingEvents()
expectNoEvents()
}
}

Expand All @@ -162,20 +159,16 @@ class AppTpEnabledNotificationContentPluginTest {
}

@Test
fun getUpdateContentOneCompanyAppTpNotEnabledThenReturnsCorrectUpdatedNotificationContent() = runTest {
fun getUpdateContentOneCompanyAppTpNotEnabledThenReturnsNoContent() = runTest {
whenever(appTrackingProtection.isEnabled()).thenReturn(false)
whenever(networkProtectionState.isEnabled()).thenReturn(false)

plugin.getUpdatedContent().test {
val trackers = listOf(aTrackerAndCompany())
appTrackerBlockingStatsRepository.insert(trackers)
db.vpnAppTrackerBlockingDao().insertTrackerEntities(trackers.map { it.asEntity() })

skipItems(1)
val item = awaitItem()

item.assertTextEquals("")

cancelAndConsumeRemainingEvents()
expectNoEvents()
}
}

Expand Down Expand Up @@ -205,8 +198,9 @@ class AppTpEnabledNotificationContentPluginTest {
}

@Test
fun getUpdateContentMultipleDifferentAppsAppTpNotEnabledThenReturnsCorrectUpdatedNotificationContent() = runTest {
fun getUpdateContentMultipleDifferentAppsAppTpNotEnabledThenReturnsNoContent() = runTest {
whenever(appTrackingProtection.isEnabled()).thenReturn(false)
whenever(networkProtectionState.isEnabled()).thenReturn(false)

plugin.getUpdatedContent().test {
val trackers = listOf(
Expand All @@ -220,12 +214,7 @@ class AppTpEnabledNotificationContentPluginTest {
appTrackerBlockingStatsRepository.insert(trackers)
db.vpnAppTrackerBlockingDao().insertTrackerEntities(trackers.map { it.asEntity() })

skipItems(1)
val item = awaitItem()

item.assertTextEquals("")

cancelAndConsumeRemainingEvents()
expectNoEvents()
}
}

Expand Down Expand Up @@ -265,16 +254,13 @@ class AppTpEnabledNotificationContentPluginTest {
}

@Test
fun getUpdateContentMultipleSameAppTpNotEnabledThenReturnsCorrectUpdatedNotificationContent() = runTest {
fun getUpdateContentMultipleSameAppTpNotEnabledThenReturnsNoContent() = runTest {
whenever(appTrackingProtection.isEnabled()).thenReturn(false)
whenever(networkProtectionState.isEnabled()).thenReturn(false)
plugin.getUpdatedContent().test {
appTrackerBlockingStatsRepository.insert(listOf(aTrackerAndCompany(), aTrackerAndCompany()))

val item = expectMostRecentItem()

item.assertTextEquals("")

cancelAndConsumeRemainingEvents()
expectNoEvents()
}
}

Expand Down
Loading
Loading