如何在单元测试中使用协程测试livedata

How to test livedata with coroutine in unit test

我正在使用 mockito、junit5 和协程在存储库(repository)中获取数据,但在测试用例中没有任何方法被调用。我尝试改用不带 Dispatchers 和 emit() 的普通挂起函数,结果可以正常工作。因此,我猜测问题可能出在 liveData 协程构建器上。

GitReposRepository.kt

/**
 * Loads the repositories of [owner] as a LiveData stream of [Result]s.
 *
 * Emits [Result.Loading] first, then calls `githubService.getReposNormal(owner)`:
 * - an empty list is emitted directly as [Result.Success];
 * - a non-empty list is persisted with `repoDao.insert(...)` and the stream then mirrors
 *   `repoDao.loadRepositories(owner)`, each value wrapped in [Result.Success].
 *
 * If the response has no body, the builder's coroutine fails with an
 * [IllegalStateException] (previously a bare `!!` NullPointerException).
 */
fun loadReposSuspend(owner: String) = liveData(Dispatchers.IO) {
    emit(Result.Loading)
    val response = githubService.getReposNormal(owner)
    val repos = checkNotNull(response.body()) {
        "No body in repos response for $owner (HTTP ${response.code()})"
    }
    if (repos.isEmpty()) {
        // Nothing to persist: report the empty list as-is.
        emit(Result.Success(repos))
    } else {
        // Persist the fetched repos before switching to the DB-backed source; inserting
        // only in the empty branch (as before) meant fetched data was never saved.
        repoDao.insert(*repos.toTypedArray())
        emitSource(repoDao.loadRepositories(owner)
                           .map { Result.Success(it) })
    }
}

GitReposRepositoryTest.kt

/**
 * Original (non-working) test for `GitRepoRepository.loadReposSuspend`.
 *
 * NOTE(review): as written, neither verification can pass:
 *  - `loadReposSuspend` returns a `liveData {}` LiveData that this test never observes.
 *    Per the androidx `liveData` builder documentation, the block only starts once the
 *    LiveData becomes active, so `getReposNormal` is never called.
 *  - Observing it from a JVM test presumably also needs an InstantTaskExecutor-style
 *    rule/extension and a test Main dispatcher. The repository hard-codes
 *    `Dispatchers.IO`, so the dispatcher should be injected to make the test deterministic.
 *  - The repository persists via `repoDao.insert(...)`, and only when the fetched list is
 *    empty. This test stubs a two-element list and verifies `repoDao.insertRepos(...)`.
 */
internal class GitRepoRepositoryTest {

    private lateinit var appExecutors:AppExecutors
    private lateinit var repoDao: RepoDao
    private lateinit var githubService: GithubService
    private lateinit var gitRepoRepository: GitRepoRepository

    // Creates fresh Mockito mocks and a repository wired to them before each test.
    @BeforeEach
    internal fun setUp() {
        appExecutors = mock(AppExecutors::class.java)
        repoDao = mock(RepoDao::class.java)
        githubService = mock(GithubService::class.java)
        gitRepoRepository = GitRepoRepository(appExecutors,
                                              repoDao,
                                              githubService)
    }

    @Test
    internal fun `should call network to fetch result and insert to db`() = runBlocking {
        //given
        val owner = "Testing"
        val response = Response.success(listOf(Repo(),Repo()))
        `when`(githubService.getReposNormal(ArgumentMatchers.anyString())).thenReturn(response)
        //when
        // NOTE(review): the returned LiveData is discarded, so the liveData block never runs.
        gitRepoRepository.loadReposSuspend(owner)
        //then
        verify(githubService).getReposNormal(owner)
        // NOTE(review): the repository calls `repoDao.insert(...)`, not `insertRepos(...)`.
        verify(repoDao).insertRepos(ArgumentMatchers.anyList())
    }
}

在网上搜索了几天后,我找到了如何结合 LiveData 和协程进行单元测试的方法,并总结出以下思路。这可能不是最好的办法,但希望能给遇到类似问题的人一些启发。

使用 LiveData 进行协程单元测试需要以下几个部分:

  1. 单元测试需要添加 2 条规则(Coroutine Rule 和 InstantExecutor Rule)。If you use JUnit 5 like me, you should use extensions instead. The Coroutine Rule lets you use the test coroutine dispatcher in JVM unit tests. The InstantExecutor Rule lets you observe the values emitted by LiveData in JVM unit tests. Note that `coroutinesRule.dispatcher` is the most important part of testing coroutines in JVM unit tests. I recommend watching this video about coroutine testing in Kotlin: https://youtu.be/KMb0Fs8rCRs

  2. 需要通过构造函数注入 CoroutineDispatcher。

    You should ALWAYS inject Dispatchers (https://youtu.be/KMb0Fs8rCRs?t=850)

  3. 一些 LiveData 扩展函数,用于验证 LiveData 发出的值。

这是我的 Repository(遵循 Android 官方推荐的应用架构 recommended app architecture):

GitRepoRepository.kt(这个思路来自两个来源:LegoThemeRepository 和 NetworkBoundResource)

/**
 * Repository exposing GitHub repositories as LiveData streams of [Result], combining the
 * local [RepoDao] with the remote [GithubService].
 *
 * @param dispatcher dispatcher handed to [repositoryLiveData]; it is a constructor
 *   parameter so tests can inject a test dispatcher instead of [Dispatchers.IO].
 * @param repoListRateLimit per-owner limiter consulted before each network refresh;
 *   defaults to `RateLimiter(10, TimeUnit.MINUTES)`.
 */
@Singleton
class GitRepoRepository @Inject constructor(
        private val appExecutors: AppExecutors,
        private val repoDao: RepoDao,
        private val githubService: GithubService,
        private val dispatcher: CoroutineDispatcher = Dispatchers.IO,
        private val repoListRateLimit: RateLimiter<String> = RateLimiter(10, TimeUnit.MINUTES)
) {

    /**
     * Streams the repositories of [owner].
     *
     * Cached rows come from `repoDao.loadRepositories(owner)`. When [repoListRateLimit]
     * allows it, they are refreshed from `githubService.getRepo(owner)`, and the fetched
     * data is stored with `repoDao.insertRepos`. A failed fetch resets the limiter for
     * [owner], presumably so the next call can retry immediately.
     */
    fun loadRepo(owner: String): LiveData<Result<List<Repo>>> = repositoryLiveData(
            dispatcher = this.dispatcher,
            shouldFetch = { repoListRateLimit.shouldFetch(owner) },
            localResult = { repoDao.loadRepositories(owner) },
            remoteResult = {
                transformResult { githubService.getRepo(owner) }.also { outcome ->
                    if (outcome is Result.Error) {
                        repoListRateLimit.reset(owner)
                    }
                }
            },
            saveFetchResult = { fetched -> repoDao.insertRepos(fetched) }
    )
    ...
}

GitRepoRepositoryTest.kt

/**
 * Tests for `GitRepoRepository.loadRepo` and its interaction with the [RateLimiter].
 *
 * The repository is constructed with `coroutinesRule.dispatcher`, so the `liveData`
 * builder runs on the test dispatcher. [InstantExecutorExtension] lets `getOrAwaitValue`
 * observe the LiveData on the JVM.
 *
 * NOTE(review): the `verify` calls after `getOrAwaitValue()` assume the builder block has
 * already run to completion on the test dispatcher. Confirm against CoroutinesTestExtension.
 */
@ExperimentalCoroutinesApi
@ExtendWith(InstantExecutorExtension::class)
class GitRepoRepositoryTest {

    private lateinit var appExecutors: AppExecutors
    private lateinit var repoDao: RepoDao
    private lateinit var githubService: GithubService
    private lateinit var gitRepoRepository: GitRepoRepository
    private lateinit var rateLimiter: RateLimiter<String>

    @BeforeEach
    fun setUp() {
        appExecutors = mock(AppExecutors::class.java)
        repoDao = mock(RepoDao::class.java)
        githubService = mock(GithubService::class.java)
        // Mockito can only mock the raw RateLimiter class; the generic cast is unchecked by design.
        @Suppress("UNCHECKED_CAST")
        val stringRateLimiter = mock(RateLimiter::class.java) as RateLimiter<String>
        rateLimiter = stringRateLimiter
        gitRepoRepository = GitRepoRepository(appExecutors,
                                              repoDao,
                                              githubService,
                                              coroutinesRule.dispatcher,
                                              rateLimiter)
    }

    @Test
    fun `should not call network to fetch result if the process in rate limiter is not valid`() = coroutinesRule.runBlocking {
        //given
        val owner = "Tom"
        val response = Response.success(listOf(Repo(), Repo()))
        `when`(githubService.getRepo(anyString())).thenReturn(response)
        `when`(rateLimiter.shouldFetch(anyString())).thenReturn(false)
        //when: observing the LiveData is what starts the liveData {} block
        gitRepoRepository.loadRepo(owner).getOrAwaitValue()
        //then
        verify(githubService, never()).getRepo(owner)
        verify(repoDao, never()).insertRepos(anyList())
    }

    @Test
    fun `should reset ratelimiter if the network response contains error`() = coroutinesRule.runBlocking {
        //given
        val owner = "Tom"
        val response = Response.error<List<Repo>>(500,
                                                  "Test Server Error".toResponseBody(
                                                          "text/plain".toMediaTypeOrNull()))
        `when`(githubService.getRepo(anyString())).thenReturn(response)
        `when`(rateLimiter.shouldFetch(anyString())).thenReturn(true)
        //when
        gitRepoRepository.loadRepo(owner).getOrAwaitValue()
        //then
        verify(rateLimiter, times(1)).reset(owner)
    }

    companion object {
        // Sets the main coroutines dispatcher for unit testing (per the original setup).
        // Assigned once, so `val`: JUnit 5 accepts final static @RegisterExtension fields.
        @JvmField
        @RegisterExtension
        val coroutinesRule = CoroutinesTestExtension()
    }
}

CoroutineUtil.kt(思路同样来自 here。如果你想记录一些信息,应在这里加入自定义实现;下面的测试用例展示了如何在协程中测试它)

/**
 * Outcome of a repository load, as streamed to observers.
 *
 * [Loading] and [Finish] bracket a load, [Success] carries data, and [Error] carries a
 * human-readable message. The singleton cases override `toString` so that failed
 * assertions over emitted value lists print readable names instead of the default
 * `ClassName@hash` identity string.
 */
sealed class Result<out R> {
    /** Data was obtained successfully. */
    data class Success<out T>(val data: T) : Result<T>()

    /** A load has started. */
    object Loading : Result<Nothing>() {
        override fun toString() = "Loading"
    }

    /**
     * A load failed with [message].
     * [T] is not used by any property; it only lets the error be emitted into a `Result<T>` stream.
     */
    data class Error<T>(val message: String) : Result<T>()

    /** Marks the end of a load. */
    object Finish : Result<Nothing>() {
        override fun toString() = "Finish"
    }
}

/**
 * Builds a LiveData that streams a locally cached value and optionally refreshes it from
 * a remote source.
 *
 * Emission order:
 * 1. [Result.Loading].
 * 2. Values of [localResult], each wrapped in [Result.Success], while that source is
 *    attached via `emitSource`.
 * 3. If [remoteResult] is non-null and [shouldFetch] returns `true`, the remote call runs:
 *    - [Result.Success]: its data is passed to [saveFetchResult];
 *    - [Result.Error]: the error is emitted, then the local source is re-attached;
 *    - any other result is ignored.
 *    An [Exception] thrown by the remote step is reported the same way as a [Result.Error].
 * 4. [Result.Finish], once the steps above complete or a failure has been reported.
 *
 * Cancellation of the builder's coroutine is rethrown rather than reported as an error,
 * and no [Result.Finish] is emitted in that case. The previous code caught it as a plain
 * `Exception` and then tried to emit from a cancelled coroutine.
 *
 * @param localResult factory for the local (cached) source; defaults to an empty MutableLiveData.
 * @param remoteResult remote call, or `null` to serve local data only.
 * @param saveFetchResult persists the data of a successful remote call.
 * @param dispatcher dispatcher passed to the `liveData` builder; inject a test dispatcher in tests.
 * @param shouldFetch decides whether the remote call runs for this load.
 */
fun <T, A> repositoryLiveData(localResult: (() -> LiveData<T>) = { MutableLiveData() },
                              remoteResult: (suspend () -> Result<A>)? = null,
                              saveFetchResult: suspend (A) -> Unit = { },
                              dispatcher: CoroutineDispatcher = Dispatchers.IO,
                              shouldFetch: () -> Boolean = { true }
): LiveData<Result<T>> =
        liveData(dispatcher) {
            emit(Result.Loading)
            val source: LiveData<Result<T>> = localResult().map { Result.Success(it) }
            emitSource(source)
            try {
                if (remoteResult != null && shouldFetch()) {
                    when (val response = remoteResult()) {
                        is Result.Success -> saveFetchResult(response.data)
                        is Result.Error -> {
                            emit(Result.Error<T>(response.message))
                            emitSource(source)
                        }
                        else -> Unit // Loading/Finish from the remote step carry no data.
                    }
                }
            } catch (e: java.util.concurrent.CancellationException) {
                // Cooperative cancellation must propagate; never report it as Result.Error.
                throw e
            } catch (e: Exception) {
                // Same message fallback as transformResult (avoids the literal "null").
                emit(Result.Error<T>(e.message ?: e.toString()))
                emitSource(source)
            }
            // NOTE(review): per the androidx liveData builder docs, `emit` removes a source
            // previously added with `emitSource`. After Finish, later changes to localResult()
            // (e.g. rows written by saveFetchResult) are presumably no longer forwarded.
            // The existing tests expect Finish as the last value, so this is left unchanged.
            // Confirm whether that is intended.
            emit(Result.Finish)
        }

/**
 * Executes [call] and converts its outcome into a [Result].
 *
 * @return [Result.Success] holding the body when the response is successful and the body
 *   is non-null. Otherwise returns a [Result.Error] built by this file's `error` helper,
 *   describing either the HTTP code/message or the thrown exception's message.
 * @throws java.util.concurrent.CancellationException if the calling coroutine is
 *   cancelled. Cancellation is propagated instead of being turned into an error result,
 *   which the previous `catch (e: Exception)` did.
 */
suspend fun <T> transformResult(call: suspend () -> Response<T>): Result<T> {
    try {
        val response = call()
        if (response.isSuccessful) {
            val body = response.body()
            if (body != null) return Result.Success(body)
        }
        return error(" ${response.code()} ${response.message()}")
    } catch (e: java.util.concurrent.CancellationException) {
        throw e
    } catch (e: Exception) {
        return error(e.message ?: e.toString())
    }
}

/**
 * Wraps [message] in a [Result.Error], prefixed with a fixed "network call has failed" text.
 *
 * NOTE(review): this shares its name with the stdlib `kotlin.error`, which throws. Calls in
 * this package such as `transformResult` resolve to this overload, which returns a value
 * instead of throwing.
 */
fun <T> error(message: String): Result<T> =
        Result.Error("Network call has failed for a following reason: $message")

CoroutineUtilKtTest.kt

/**
 * Test seam mirroring the lambdas accepted by `repositoryLiveData`. It is mocked in
 * CoroutineUtilKtTest so each callback can be stubbed and verified.
 */
interface Delegation {
    /** Backs the `shouldFetch` lambda. */
    fun shouldFetch(): Boolean

    /** Backs the `localResult` lambda. */
    fun localResult(): MutableLiveData<String>

    /** Backs the `remoteResult` lambda. */
    suspend fun remoteResult(): Result<String>

    /** Backs the `saveFetchResult` lambda. */
    suspend fun saveResult(s: String)
}

/**
 * BDD-style stubbing for a suspending mock call.
 *
 * Runs [block] inside `runBlocking` so the suspend call reaches the mock, then passes the
 * result to `BDDMockito.given`, which picks up that last mock invocation for stubbing.
 */
fun <T> givenSuspended(block: suspend () -> T): BDDMockito.BDDMyOngoingStubbing<T> {
    val invocationResult = runBlocking { block() }
    return BDDMockito.given(invocationResult)
}

/**
 * Tests for `repositoryLiveData`: which [Delegation] callbacks run, and the exact
 * sequence of [Result] values the LiveData emits.
 *
 * Each test passes `coroutinesRule.dispatcher` to the builder. [InstantExecutorExtension]
 * lets the LiveData be observed on the JVM.
 *
 * NOTE(review): the sequence assertions assume the test dispatcher executes the builder
 * block eagerly while the observer is attached. Confirm against CoroutinesTestExtension.
 */
@ExperimentalCoroutinesApi
@ExtendWith(InstantExecutorExtension::class)
class CoroutineUtilKtTest {

    private val delegation: Delegation = mock()

    @BeforeEach
    fun setUp() {
        // Defaults: fetching is allowed, local data exists, and the remote call succeeds.
        given { delegation.shouldFetch() }
                .willReturn(true)
        given { delegation.localResult() }
                .willReturn(MutableLiveData(LOCAL_RESULT))
        givenSuspended { delegation.remoteResult() }
                .willReturn(Result.Success(REMOTE_RESULT))
    }

    @Test
    fun `should call local result only if the remote result should not fetch`() = coroutinesRule.runBlocking {
        //given
        given { delegation.shouldFetch() }.willReturn(false)

        //when
        repositoryLiveData<String, String>(
                localResult = { delegation.localResult() },
                remoteResult = { delegation.remoteResult() },
                shouldFetch = { delegation.shouldFetch() },
                dispatcher = coroutinesRule.dispatcher
        ).getOrAwaitValue()
        //then
        verify(delegation, times(1)).localResult()
        verify(delegation, never()).remoteResult()
    }

    @Test
    fun `should call remote result and then save result`() = coroutinesRule.runBlocking {
        //when
        repositoryLiveData<String, String>(
                shouldFetch = { delegation.shouldFetch() },
                remoteResult = { delegation.remoteResult() },
                saveFetchResult = { s -> delegation.saveResult(s) },
                dispatcher = coroutinesRule.dispatcher
        ).getOrAwaitValue()
        //then
        verify(delegation, times(1)).remoteResult()
        verify(delegation, times(1)).saveResult(REMOTE_RESULT)
    }

    @Test
    fun `should emit Loading, Success, Finish Status when we fetch local and then remote`() = coroutinesRule.runBlocking {
        //when
        val ld = repositoryLiveData<String, String>(
                localResult = { delegation.localResult() },
                shouldFetch = { delegation.shouldFetch() },
                remoteResult = { delegation.remoteResult() },
                // Persist through saveResult; this was previously wired to shouldFetch by mistake.
                saveFetchResult = { s -> delegation.saveResult(s) },
                dispatcher = coroutinesRule.dispatcher
        )
        //then
        ld.captureValues {
            assertEquals(arrayListOf(Result.Loading,
                                     Result.Success(LOCAL_RESULT),
                                     Result.Finish), values)
        }
    }

    @Test
    fun `should emit Loading,Success, Error, Success, Finish Status when we fetch remote but fail`() = coroutinesRule.runBlocking {
        givenSuspended { delegation.remoteResult() }
                .willThrow(RuntimeException(REMOTE_CRASH))
        //when
        val ld = repositoryLiveData<String, String>(
                localResult = { delegation.localResult() },
                shouldFetch = { delegation.shouldFetch() },
                remoteResult = { delegation.remoteResult() },
                // Persist through saveResult; this was previously wired to shouldFetch by mistake.
                saveFetchResult = { s -> delegation.saveResult(s) },
                dispatcher = coroutinesRule.dispatcher
        )
        //then
        ld.captureValues {
            assertEquals(arrayListOf(Result.Loading,
                                     Result.Success(LOCAL_RESULT),
                                     Result.Error<String>(REMOTE_CRASH),
                                     Result.Success(LOCAL_RESULT),
                                     Result.Finish
            ), values)
        }
    }

    companion object {
        // Sets the main coroutines dispatcher for unit testing (per the original setup).
        // Assigned once, so `val`: JUnit 5 accepts final static @RegisterExtension fields.
        @JvmField
        @RegisterExtension
        val coroutinesRule = CoroutinesTestExtension()

        private const val LOCAL_RESULT = "Local Result Fetch"
        private const val REMOTE_RESULT = "Remote Result Fetch"
        private const val REMOTE_CRASH = "Remote Result Crash"
    }
}

LiveDataTestUtil.kt(思路来自 aac sample 和 kotlin-coroutine)

/**
 * Observes this LiveData until it delivers a value, then returns that value.
 *
 * The observer detaches itself after the first delivery. If nothing arrives within
 * [time] [timeUnit], the observer is removed and a [TimeoutException] is thrown.
 *
 * @param afterObserve hook run right after the observer is attached and before waiting,
 *   e.g. to trigger the action expected to produce a value.
 *
 * NOTE(review): `observeForever` is normally main-thread only. In JVM tests this
 * presumably relies on InstantExecutorExtension or an equivalent; confirm.
 */
fun <T> LiveData<T>.getOrAwaitValue(
        time: Long = 2,
        timeUnit: TimeUnit = TimeUnit.SECONDS,
        afterObserve: () -> Unit = {}
): T {
    val liveData = this
    var received: T? = null
    val arrival = CountDownLatch(1)
    val oneShotObserver = object : Observer<T> {
        override fun onChanged(o: T?) {
            received = o
            arrival.countDown()
            liveData.removeObserver(this)
        }
    }
    observeForever(oneShotObserver)

    afterObserve()

    // Bound the wait so a LiveData that never emits fails fast instead of hanging the test.
    val delivered = arrival.await(time, timeUnit)
    if (!delivered) {
        removeObserver(oneShotObserver)
        throw TimeoutException("LiveData value was never set.")
    }

    @Suppress("UNCHECKED_CAST")
    return received as T
}

/**
 * Thread-safe accumulator for values observed from a LiveData.
 *
 * All access to the recorded list happens while holding [lock]. [values] returns a
 * snapshot copy, so callers never see later additions through a returned list.
 */
class LiveDataValueCapture<T> {

    /** Monitor guarding the recorded values. */
    val lock = Any()

    private val recorded: MutableList<T?> = mutableListOf()

    /** Snapshot of every recorded value so far, in arrival order (nulls included). */
    val values: List<T?>
        get() = synchronized(lock) { recorded.toList() }

    /** Appends [value], which may be null, to the record. */
    fun addValue(value: T?) {
        synchronized(lock) { recorded.add(value) }
    }
}

/**
 * Records every value this LiveData emits while [block] runs, then stops observing.
 *
 * [block] runs with the [LiveDataValueCapture] as its receiver, so it can read `values`
 * directly. The observer is removed even if [block] throws, e.g. on a failed assertion.
 */
inline fun <T> LiveData<T>.captureValues(block: LiveDataValueCapture<T>.() -> Unit) {
    val recorder = LiveDataValueCapture<T>()
    val recordingObserver = Observer<T> { emitted -> recorder.addValue(emitted) }
    observeForever(recordingObserver)
    try {
        block(recorder)
    } finally {
        removeObserver(recordingObserver)
    }
}