Skip to content

Commit

Permalink
Fix crash with multiple resumes in VpnBackend.resetSockets
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusz-markowicz committed Jun 10, 2024
1 parent af441f4 commit a07cee9
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 16 deletions.
25 changes: 25 additions & 0 deletions app/src/main/java/com/protonvpn/android/utils/FlowUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@
package com.protonvpn.android.utils

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.ChannelResult
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.flatMapLatest
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.withTimeoutOrNull
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds

Expand Down Expand Up @@ -82,3 +87,23 @@ fun tickFlow(step: Duration, clock: () -> Long) = flow {
lastTick = now
}
}

// Util serving as a safer alternative for suspendCancellableCoroutine (where resume throws exception
// when called after cancellation and need to be guarded with isActive).
suspend fun <T> suspendForCallback(
onClose: () -> Unit,
registerCallback: (resume: (T) -> ChannelResult<Unit>) -> Unit
): T? =
callbackFlow {
registerCallback { trySend(it) }
awaitClose { onClose() }
}.first()

suspend fun <T> suspendForCallbackWithTimeout(
timeoutMs: Long,
onClose: () -> Unit,
registerCallback: (resume: (T) -> ChannelResult<Unit>) -> Unit
): T? =
withTimeoutOrNull(timeoutMs) {
suspendForCallback(onClose, registerCallback)
}
24 changes: 10 additions & 14 deletions app/src/main/java/com/protonvpn/android/vpn/VpnBackend.kt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import com.protonvpn.android.ui.home.GetNetZone
import com.protonvpn.android.utils.Constants
import com.protonvpn.android.utils.Storage
import com.protonvpn.android.utils.SyncStateFlow
import com.protonvpn.android.utils.suspendForCallbackWithTimeout
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Job
Expand All @@ -60,15 +61,13 @@ import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.coroutines.yield
import me.proton.core.network.data.di.SharedOkHttpClient
import me.proton.core.network.domain.NetworkManager
import me.proton.core.network.domain.NetworkStatus
import okhttp3.OkHttpClient
import kotlin.coroutines.resume

data class PrepareResult(val backend: VpnBackend, val connectionParams: ConnectionParams) : java.io.Serializable

Expand Down Expand Up @@ -546,19 +545,16 @@ private suspend fun OkHttpClient.resetSockets() {
// Cancel all running calls
dispatcher.cancelAll()

val timedOut = null == withTimeoutOrNull(500) {
suspendCancellableCoroutine { continuation ->
val original = dispatcher.idleCallback?.unwrapIdleCallback()
dispatcher.idleCallback =
OkHttpIdleCallbackWrapper(original) {
if (continuation.isActive) { // It can be cancelled by the time this gets executed.
continuation.resume(Unit)
dispatcher.idleCallback = original
}
}
continuation.invokeOnCancellation { dispatcher.idleCallback = original }
val original = dispatcher.idleCallback?.unwrapIdleCallback()
val timedOut = null == suspendForCallbackWithTimeout(
500,
onClose = { dispatcher.idleCallback = original },
registerCallback = { resume ->
dispatcher.idleCallback = OkHttpIdleCallbackWrapper(original) {
resume(Unit)
}
}
}
)
if (timedOut)
ProtonLogger.log(ConnStateChanged, "Tunnel opened: timed-out waiting for OkHttp idle")
}
Expand Down
32 changes: 32 additions & 0 deletions app/src/test/java/com/protonvpn/app/utils/FlowUtilsTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,26 @@
package com.protonvpn.app.utils

import com.protonvpn.android.utils.mapState
import com.protonvpn.android.utils.suspendForCallback
import com.protonvpn.android.utils.suspendForCallbackWithTimeout
import com.protonvpn.android.utils.tickFlow
import com.protonvpn.android.utils.withPrevious
import com.protonvpn.test.shared.runWhileCollecting
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.currentTime
import kotlinx.coroutines.test.runCurrent
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
import org.junit.Test
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.seconds

@OptIn(ExperimentalCoroutinesApi::class)
Expand Down Expand Up @@ -83,4 +88,31 @@ class FlowUtilsTests {
}
assertEquals(listOf(0L, 1000L, 2000L, 3000L), timestamps)
}

@Test
fun `suspendForCallback waits for callback`() = runTest {
var closed = false
val result = suspendForCallback(onClose = { closed = true }) { resume ->
launch {
delay(100)
resume(1)
resume(2) // This should get ignored
}
}
assertEquals(1, result)
assertTrue(closed)
}

@Test
fun suspendForCallbackWithTimeout() = runTest {
var closed = false
val result = suspendForCallbackWithTimeout(500, onClose = { closed = true }) { resume ->
launch {
delay(1000)
resume(1)
}
}
assertEquals(null, result)
assertTrue(closed)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@

package com.protonvpn.test.shared

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.TestScope

@OptIn(ExperimentalCoroutinesApi::class)
fun <T> TestScope.runWhileCollecting(flow: Flow<T>, block: suspend () -> Unit): List<T> {
val collectedValues = mutableListOf<T>()
val collectJob = backgroundScope.launch {
Expand Down

0 comments on commit a07cee9

Please sign in to comment.