diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index f29d98f0ce9..e7e1a1637a8 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1911,6 +1911,9 @@ class UserSessionScope internal constructor( val search: SearchScope by lazy { SearchScope( + mlsPublicKeysRepository = mlsPublicKeysRepository, + getDefaultProtocol = getDefaultProtocol, + getConversationProtocolInfo = conversations.getConversationProtocolInfo, searchUserRepository = searchUserRepository, selfUserId = userId, sessionRepository = globalScope.sessionRepository, diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/search/IsFederationSearchAllowedUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/search/IsFederationSearchAllowedUseCase.kt new file mode 100644 index 00000000000..964b17c01e8 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/search/IsFederationSearchAllowedUseCase.kt @@ -0,0 +1,82 @@ +/* + * Wire + * Copyright (C) 2024 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.search + +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.mls.MLSPublicKeys +import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase +import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.util.KaliumDispatcher +import com.wire.kalium.util.KaliumDispatcherImpl +import kotlinx.coroutines.withContext + +/** + * Check if FederatedSearchIsAllowed according to MLS configuration of the backend + * and the conversation protocol if a [ConversationId] is provided. + */ +interface IsFederationSearchAllowedUseCase { + suspend operator fun invoke(conversationId: ConversationId?): Boolean +} + +@Suppress("FunctionNaming") +internal fun IsFederationSearchAllowedUseCase( + mlsPublicKeysRepository: MLSPublicKeysRepository, + getDefaultProtocol: GetDefaultProtocolUseCase, + getConversationProtocolInfo: GetConversationProtocolInfoUseCase, + dispatcher: KaliumDispatcher = KaliumDispatcherImpl +) = object : IsFederationSearchAllowedUseCase { + + override suspend operator fun invoke(conversationId: ConversationId?): Boolean = withContext(dispatcher.io) { + val isMlsConfiguredForBackend = hasMLSKeysConfiguredForBackend() + when (isMlsConfiguredForBackend) { + true -> isConversationProtocolAbleToFederate(conversationId) + false -> true + } + } + + private suspend fun hasMLSKeysConfiguredForBackend(): Boolean { + return when (val mlsKeysResult = mlsPublicKeysRepository.getKeys()) { + is Either.Left -> false + is Either.Right -> { + val mlsKeys: MLSPublicKeys = mlsKeysResult.value + mlsKeys.removal != null && mlsKeys.removal?.isNotEmpty() == true + } + } + } + + /** + * MLS is enabled, then we need to check if the protocol for the conversation is able to federate. + */ + private suspend fun isConversationProtocolAbleToFederate(conversationId: ConversationId?): Boolean { + val isProteusTeam = getDefaultProtocol() == SupportedProtocol.PROTEUS + val isOtherDomainAllowed: Boolean = conversationId?.let { + when (val result = getConversationProtocolInfo(it)) { + is GetConversationProtocolInfoUseCase.Result.Failure -> !isProteusTeam + + is GetConversationProtocolInfoUseCase.Result.Success -> + !isProteusTeam && result.protocolInfo !is Conversation.ProtocolInfo.Proteus + } + } ?: !isProteusTeam + return isOtherDomainAllowed + } + +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/search/SearchScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/search/SearchScope.kt index 662514c9447..bb5c11a9ff9 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/search/SearchScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/search/SearchScope.kt @@ -17,12 +17,19 @@ */ package com.wire.kalium.logic.feature.search +import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.data.publicuser.SearchUserRepository import com.wire.kalium.logic.data.session.SessionRepository import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase +import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase import com.wire.kalium.logic.featureFlags.KaliumConfigs +@Suppress("LongParameterList") class SearchScope internal constructor( + private val mlsPublicKeysRepository: MLSPublicKeysRepository, + private val getDefaultProtocol: GetDefaultProtocolUseCase, + private val getConversationProtocolInfo: GetConversationProtocolInfoUseCase, private val searchUserRepository: SearchUserRepository, private val sessionRepository: SessionRepository, private val selfUserId: UserId, @@ -42,4 +49,7 @@ class SearchScope internal constructor( kaliumConfigs.maxRemoteSearchResultCount ) val federatedSearchParser: FederatedSearchParser get() = FederatedSearchParser(sessionRepository, selfUserId) + + val isFederationSearchAllowedUseCase: IsFederationSearchAllowedUseCase + get() = IsFederationSearchAllowedUseCase(mlsPublicKeysRepository, getDefaultProtocol, getConversationProtocolInfo) } diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/search/IsFederationSearchAllowedUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/search/IsFederationSearchAllowedUseCaseTest.kt new file mode 100644 index 00000000000..0ac583443fd --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/search/IsFederationSearchAllowedUseCaseTest.kt @@ -0,0 +1,166 @@ +/* + * Wire + * Copyright (C) 2024 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.feature.search + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.data.mls.MLSPublicKeys +import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository +import com.wire.kalium.logic.data.user.SupportedProtocol +import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase +import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestConversation.PROTEUS_PROTOCOL_INFO +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.util.KaliumDispatcherImpl +import io.mockative.Mock +import io.mockative.any +import io.mockative.coEvery +import io.mockative.coVerify +import io.mockative.every +import io.mockative.mock +import io.mockative.once +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class IsFederationSearchAllowedUseCaseTest { + + @Test + fun givenMLSIsNotConfigured_whenInvokingIsFederationSearchAllowed_thenReturnTrue() = runTest { + val (arrangement, isFederationSearchAllowedUseCase) = Arrangement() + .withMLSConfiguredForBackend(isConfigured = false) + .arrange() + + val isAllowed = isFederationSearchAllowedUseCase(conversationId = null) + + assertEquals(true, isAllowed) + coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once) + coVerify { arrangement.getDefaultProtocol.invoke() }.wasNotInvoked() + coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasNotInvoked() + } + + @Test + fun givenMLSIsConfiguredAndAMLSTeamWithEmptyKeys_whenInvokingIsFederationSearchAllowed_thenReturnTrue() = runTest { + val (arrangement, isFederationSearchAllowedUseCase) = Arrangement() + .withEmptyMlsKeys() + .arrange() + + val isAllowed = isFederationSearchAllowedUseCase(conversationId = null) + + assertEquals(true, isAllowed) + coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once) + coVerify { arrangement.getDefaultProtocol.invoke() }.wasNotInvoked() + coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasNotInvoked() + } + + @Test + fun givenMLSIsConfiguredAndAMLSTeam_whenInvokingIsFederationSearchAllowed_thenReturnTrue() = runTest { + val (arrangement, isFederationSearchAllowedUseCase) = Arrangement() + .withMLSConfiguredForBackend(isConfigured = true) + .withDefaultProtocol(SupportedProtocol.MLS) + .arrange() + + val isAllowed = isFederationSearchAllowedUseCase(conversationId = null) + + assertEquals(true, isAllowed) + coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once) + coVerify { arrangement.getDefaultProtocol.invoke() }.wasInvoked(once) + coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasNotInvoked() + } + + @Test + fun givenMLSIsConfiguredAndAMLSTeamAndProteusProtocol_whenInvokingIsFederationSearchAllowed_thenReturnFalse() = runTest { + val (arrangement, isFederationSearchAllowedUseCase) = Arrangement() + .withMLSConfiguredForBackend(isConfigured = true) + .withDefaultProtocol(SupportedProtocol.MLS) + .withConversationProtocolInfo(GetConversationProtocolInfoUseCase.Result.Success(PROTEUS_PROTOCOL_INFO)) + .arrange() + + val isAllowed = isFederationSearchAllowedUseCase(conversationId = TestConversation.ID) + + assertEquals(false, isAllowed) + coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once) + coVerify { arrangement.getDefaultProtocol.invoke() }.wasInvoked(once) + coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasInvoked(once) + } + + @Test + fun givenMLSIsConfiguredAndAProteusTeamAndProteusProtocol_whenInvokingIsFederationSearchAllowed_thenReturnFalse() = runTest { + val (arrangement, isFederationSearchAllowedUseCase) = Arrangement() + .withMLSConfiguredForBackend(isConfigured = true) + .withDefaultProtocol(SupportedProtocol.PROTEUS) + .withConversationProtocolInfo(GetConversationProtocolInfoUseCase.Result.Success(PROTEUS_PROTOCOL_INFO)) + .arrange() + + val isAllowed = isFederationSearchAllowedUseCase(conversationId = TestConversation.ID) + + assertEquals(false, isAllowed) + coVerify { arrangement.mlsPublicKeysRepository.getKeys() }.wasInvoked(once) + coVerify { arrangement.getDefaultProtocol.invoke() }.wasInvoked(once) + coVerify { arrangement.getConversationProtocolInfo.invoke(any()) }.wasInvoked(once) + } + + private class Arrangement { + + @Mock + val mlsPublicKeysRepository = mock(MLSPublicKeysRepository::class) + + @Mock + val getDefaultProtocol = mock(GetDefaultProtocolUseCase::class) + + @Mock + val getConversationProtocolInfo = mock(GetConversationProtocolInfoUseCase::class) + + private val MLS_PUBLIC_KEY = MLSPublicKeys( + removal = mapOf( + "ed25519" to "gRNvFYReriXbzsGu7zXiPtS8kaTvhU1gUJEV9rdFHVw=" + ) + ) + + fun withDefaultProtocol(protocol: SupportedProtocol) = apply { + every { getDefaultProtocol.invoke() }.returns(protocol) + } + + suspend fun withConversationProtocolInfo(protocolInfo: GetConversationProtocolInfoUseCase.Result) = apply { + coEvery { getConversationProtocolInfo(any()) }.returns(protocolInfo) + } + + suspend fun withMLSConfiguredForBackend(isConfigured: Boolean = true) = apply { + coEvery { mlsPublicKeysRepository.getKeys() }.returns( + if (isConfigured) { + Either.Right(MLS_PUBLIC_KEY) + } else { + Either.Left(CoreFailure.Unknown(RuntimeException("MLS is not configured"))) + } + ) + } + + suspend fun withEmptyMlsKeys() = apply { + coEvery { mlsPublicKeysRepository.getKeys() }.returns(Either.Right(MLSPublicKeys(emptyMap()))) + } + + fun arrange() = this to IsFederationSearchAllowedUseCase( + mlsPublicKeysRepository = mlsPublicKeysRepository, + getDefaultProtocol = getDefaultProtocol, + getConversationProtocolInfo = getConversationProtocolInfo, + dispatcher = KaliumDispatcherImpl + ) + } +} + +