diff --git a/app/src/main/kotlin/com/wire/android/di/accountScoped/SearchModule.kt b/app/src/main/kotlin/com/wire/android/di/accountScoped/SearchModule.kt index 02f80574ae6..e5bd7ef34a8 100644 --- a/app/src/main/kotlin/com/wire/android/di/accountScoped/SearchModule.kt +++ b/app/src/main/kotlin/com/wire/android/di/accountScoped/SearchModule.kt @@ -22,6 +22,7 @@ import com.wire.android.di.KaliumCoreLogic import com.wire.kalium.logic.CoreLogic import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.feature.search.FederatedSearchParser +import com.wire.kalium.logic.feature.search.IsFederationSearchAllowedUseCase import com.wire.kalium.logic.feature.search.SearchByHandleUseCase import com.wire.kalium.logic.feature.search.SearchScope import com.wire.kalium.logic.feature.search.SearchUsersUseCase @@ -53,4 +54,9 @@ class SearchModule { @ViewModelScoped @Provides fun provideFederatedSearchParser(searchScope: SearchScope): FederatedSearchParser = searchScope.federatedSearchParser + + @ViewModelScoped + @Provides + fun provideIsFederationSearchAllowedUseCase(searchScope: SearchScope): IsFederationSearchAllowedUseCase = + searchScope.isFederationSearchAllowedUseCase } diff --git a/app/src/main/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModel.kt b/app/src/main/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModel.kt index d40a68298d4..e7a87b34957 100644 --- a/app/src/main/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModel.kt +++ b/app/src/main/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModel.kt @@ -27,16 +27,13 @@ import com.wire.android.mapper.ContactMapper import com.wire.android.ui.home.newconversation.model.Contact import com.wire.android.ui.navArgs import com.wire.android.util.EMPTY -import com.wire.kalium.logic.data.conversation.Conversation -import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.feature.auth.ValidateUserHandleResult import com.wire.kalium.logic.feature.auth.ValidateUserHandleUseCase -import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase import com.wire.kalium.logic.feature.search.FederatedSearchParser +import com.wire.kalium.logic.feature.search.IsFederationSearchAllowedUseCase import com.wire.kalium.logic.feature.search.SearchByHandleUseCase import com.wire.kalium.logic.feature.search.SearchUserResult import com.wire.kalium.logic.feature.search.SearchUsersUseCase -import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase import dagger.hilt.android.lifecycle.HiltViewModel import kotlinx.collections.immutable.ImmutableList import kotlinx.collections.immutable.ImmutableSet @@ -59,8 +56,7 @@ class SearchUserViewModel @Inject constructor( private val contactMapper: ContactMapper, private val federatedSearchParser: FederatedSearchParser, private val validateUserHandle: ValidateUserHandleUseCase, - private val getDefaultProtocol: GetDefaultProtocolUseCase, - private val getConversationProtocolInfo: GetConversationProtocolInfoUseCase, + private val isFederationSearchAllowed: IsFederationSearchAllowedUseCase, savedStateHandle: SavedStateHandle ) : ViewModel() { @@ -78,17 +74,7 @@ class SearchUserViewModel @Inject constructor( init { viewModelScope.launch { - val isProteusTeam = getDefaultProtocol() == SupportedProtocol.PROTEUS - - val isOtherDomainAllowed: Boolean = addMembersSearchNavArgs?.conversationId?.let { conversationId -> - when (val result = getConversationProtocolInfo(conversationId)) { - is GetConversationProtocolInfoUseCase.Result.Failure -> !isProteusTeam - - is GetConversationProtocolInfoUseCase.Result.Success -> - !isProteusTeam && result.protocolInfo !is Conversation.ProtocolInfo.Proteus - } - } ?: !isProteusTeam - + val isOtherDomainAllowed = isFederationSearchAllowed(addMembersSearchNavArgs?.conversationId) state = state.copy(isOtherDomainAllowed = isOtherDomainAllowed) } diff --git a/app/src/test/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModelTest.kt b/app/src/test/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModelTest.kt index f3bf3afed96..b3e2b1d1d4e 100644 --- a/app/src/test/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModelTest.kt +++ b/app/src/test/kotlin/com/wire/android/ui/home/conversations/search/SearchUserViewModelTest.kt @@ -33,17 +33,15 @@ import com.wire.kalium.logic.data.id.GroupID import com.wire.kalium.logic.data.mls.CipherSuite import com.wire.kalium.logic.data.publicuser.model.UserSearchDetails import com.wire.kalium.logic.data.user.ConnectionState -import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserId import com.wire.kalium.logic.data.user.type.UserType import com.wire.kalium.logic.feature.auth.ValidateUserHandleResult import com.wire.kalium.logic.feature.auth.ValidateUserHandleUseCase -import com.wire.kalium.logic.feature.conversation.GetConversationProtocolInfoUseCase import com.wire.kalium.logic.feature.search.FederatedSearchParser +import com.wire.kalium.logic.feature.search.IsFederationSearchAllowedUseCase import com.wire.kalium.logic.feature.search.SearchByHandleUseCase import com.wire.kalium.logic.feature.search.SearchUserResult import com.wire.kalium.logic.feature.search.SearchUsersUseCase -import com.wire.kalium.logic.feature.user.GetDefaultProtocolUseCase import io.mockk.MockKAnnotations import io.mockk.coEvery import io.mockk.coVerify @@ -138,10 +136,9 @@ class SearchUserViewModelTest { fun `given Proteus conversation and MLS team, when calling the searchUseCase, then otherDomain is not allowed`() = runTest { val conversationId = ConversationId("id", "domain") - val (arrangement, viewModel) = Arrangement() + val (_, viewModel) = Arrangement() .withAddMembersSearchNavArgs(AddMembersSearchNavArgs(conversationId, true)) - .withConversationProtocolInfo(GetConversationProtocolInfoUseCase.Result.Success(Conversation.ProtocolInfo.Proteus)) - .withDefaultProtocol(SupportedProtocol.MLS) + .withIsFederationSearchAllowedResult(false) .withIsValidHandleResult(ValidateUserHandleResult.Valid("")) .withFederatedSearchParserResult( FederatedSearchParser.Result( @@ -164,10 +161,9 @@ class SearchUserViewModelTest { fun `given MLS conversation and Proteus team, when calling the searchUseCase, then otherDomain is not allowed`() = runTest { val conversationId = ConversationId("id", "domain") - val (arrangement, viewModel) = Arrangement() + val (_, viewModel) = Arrangement() .withAddMembersSearchNavArgs(AddMembersSearchNavArgs(conversationId, true)) - .withConversationProtocolInfo(GetConversationProtocolInfoUseCase.Result.Success(mlsProtocol)) - .withDefaultProtocol(SupportedProtocol.PROTEUS) + .withIsFederationSearchAllowedResult(false) .withIsValidHandleResult(ValidateUserHandleResult.Valid("")) .withFederatedSearchParserResult( FederatedSearchParser.Result( @@ -190,10 +186,9 @@ class SearchUserViewModelTest { fun `given MLS conversation and MLS team, when calling the searchUseCase, then otherDomain is allowed`() = runTest { val conversationId = ConversationId("id", "domain") - val (arrangement, viewModel) = Arrangement() + val (_, viewModel) = Arrangement() .withAddMembersSearchNavArgs(AddMembersSearchNavArgs(conversationId, true)) - .withConversationProtocolInfo(GetConversationProtocolInfoUseCase.Result.Success(mlsProtocol)) - .withDefaultProtocol(SupportedProtocol.MLS) + .withIsFederationSearchAllowedResult(true) .withIsValidHandleResult(ValidateUserHandleResult.Valid("")) .withFederatedSearchParserResult( FederatedSearchParser.Result( @@ -360,10 +355,7 @@ class SearchUserViewModelTest { lateinit var searchByHandleUseCase: SearchByHandleUseCase @MockK - lateinit var getDefaultProtocolUseCase: GetDefaultProtocolUseCase - - @MockK - lateinit var getConversationProtocolInfo: GetConversationProtocolInfoUseCase + lateinit var isFederationSearchAllowedUseCase: IsFederationSearchAllowedUseCase init { MockKAnnotations.init(this, relaxUnitFun = true) @@ -371,10 +363,7 @@ class SearchUserViewModelTest { val user = args.get(0) as UserSearchDetails fromSearchUserResult(user) } - every { getDefaultProtocolUseCase() } returns SupportedProtocol.PROTEUS - coEvery { - getConversationProtocolInfo(any()) - } returns GetConversationProtocolInfoUseCase.Result.Success(Conversation.ProtocolInfo.Proteus) + withIsFederationSearchAllowedResult(false) } fun fromSearchUserResult(user: UserSearchDetails): Contact { @@ -427,12 +416,8 @@ class SearchUserViewModelTest { coEvery { searchByHandleUseCase(any(), any(), any()) } returns result } - suspend fun withConversationProtocolInfo(result: GetConversationProtocolInfoUseCase.Result) = apply { - coEvery { getConversationProtocolInfo(any()) } returns result - } - - fun withDefaultProtocol(protocol: SupportedProtocol) = apply { - every { getDefaultProtocolUseCase() } returns protocol + fun withIsFederationSearchAllowedResult(isAllowed: Boolean = true) = apply { + coEvery { isFederationSearchAllowedUseCase(any()) } returns isAllowed } private lateinit var searchUserViewModel: SearchUserViewModel @@ -444,8 +429,7 @@ class SearchUserViewModelTest { contactMapper = contactMapper, federatedSearchParser = federatedSearchParser, validateUserHandle = validateUserHandle, - getConversationProtocolInfo = getConversationProtocolInfo, - getDefaultProtocol = getDefaultProtocolUseCase, + isFederationSearchAllowed = isFederationSearchAllowedUseCase, savedStateHandle = savedStateHandle ) }.run {