Skip to content

Commit

Permalink
fix: MLS groups are not usable if I have a proteus client registered …
Browse files Browse the repository at this point in the history
…to my account but have MLS on current client [WPB-15192] (#3197)

* fix: MLS groups are not usable if I have a proteus client registered to my account but have MLS on current client [WPB-15192]

* detekt

* detekt

* fix: pr comments

---------

Co-authored-by: yamilmedina <[email protected]>
  • Loading branch information
2 people authored and github-actions[bot] committed Dec 31, 2024
1 parent 6ddef73 commit 2818d42
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.wire.kalium.logger.obfuscateDomain
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.MemberMapper
import com.wire.kalium.logic.data.conversation.Recipient
import com.wire.kalium.logic.data.conversation.mls.NameAndHandle
Expand Down Expand Up @@ -167,6 +168,7 @@ interface UserRepository {
suspend fun getNameAndHandle(userId: UserId): Either<StorageFailure, NameAndHandle>
suspend fun migrateUserToTeam(teamName: String): Either<CoreFailure, CreateUserTeam>
suspend fun updateTeamId(userId: UserId, teamId: TeamId): Either<StorageFailure, Unit>
suspend fun isClientMlsCapable(userId: UserId, clientId: ClientId): Either<StorageFailure, Boolean>
}

@Suppress("LongParameterList", "TooManyFunctions")
Expand Down Expand Up @@ -668,6 +670,10 @@ internal class UserDataSource internal constructor(
userDAO.updateTeamId(userId.toDao(), teamId.value)
}

override suspend fun isClientMlsCapable(userId: UserId, clientId: ClientId): Either<StorageFailure, Boolean> = wrapStorageRequest {
clientDAO.isMLSCapable(userId.toDao(), clientId.value)
}

companion object {
internal const val SELF_USER_ID_KEY = "selfUserID"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@ class ConversationScope internal constructor(
get() = ObserveIsSelfUserMemberUseCaseImpl(conversationRepository, selfUserId)

val observeConversationInteractionAvailabilityUseCase: ObserveConversationInteractionAvailabilityUseCase
get() = ObserveConversationInteractionAvailabilityUseCase(conversationRepository, userRepository)
get() = ObserveConversationInteractionAvailabilityUseCase(
conversationRepository,
selfUserId = selfUserId,
selfClientIdProvider = currentClientIdProvider,
userRepository = userRepository
)

val deleteTeamConversation: DeleteTeamConversationUseCase
get() = DeleteTeamConversationUseCaseImpl(selfTeamIdProvider, teamRepository, conversationRepository)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.InteractionAvailability
import com.wire.kalium.logic.data.conversation.interactionAvailability
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.getOrElse
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.util.KaliumDispatcher
import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext

Expand All @@ -48,6 +51,8 @@ import kotlinx.coroutines.withContext
class ObserveConversationInteractionAvailabilityUseCase internal constructor(
private val conversationRepository: ConversationRepository,
private val userRepository: UserRepository,
private val selfUserId: UserId,
private val selfClientIdProvider: CurrentClientIdProvider,
private val dispatcher: KaliumDispatcher = KaliumDispatcherImpl,
) {

Expand All @@ -56,13 +61,21 @@ class ObserveConversationInteractionAvailabilityUseCase internal constructor(
* @return an [IsInteractionAvailableResult] containing Success or Failure cases
*/
suspend operator fun invoke(conversationId: ConversationId): Flow<IsInteractionAvailableResult> = withContext(dispatcher.io) {
conversationRepository.observeConversationDetailsById(conversationId).combine(
userRepository.observeSelfUser()
) { conversation, selfUser ->
conversation to selfUser
}.map { (eitherConversation, selfUser) ->

val isSelfClientMlsCapable = selfClientIdProvider().flatMap {
userRepository.isClientMlsCapable(selfUserId, it)
}.getOrElse {
return@withContext flow { IsInteractionAvailableResult.Failure(it) }
}

kaliumLogger.withTextTag("ObserveConversationInteractionAvailabilityUseCase").d("isSelfClientMlsCapable $isSelfClientMlsCapable")

conversationRepository.observeConversationDetailsById(conversationId).map { eitherConversation ->
eitherConversation.fold({ failure -> IsInteractionAvailableResult.Failure(failure) }, { conversationDetails ->
val isProtocolSupported = doesUserSupportConversationProtocol(conversationDetails, selfUser)
val isProtocolSupported = doesUserSupportConversationProtocol(
conversationDetails = conversationDetails,
isSelfClientMlsCapable = isSelfClientMlsCapable
)
if (!isProtocolSupported) { // short-circuit to Unsupported Protocol if it's the case
return@fold IsInteractionAvailableResult.Success(InteractionAvailability.UNSUPPORTED_PROTOCOL)
}
Expand All @@ -74,19 +87,12 @@ class ObserveConversationInteractionAvailabilityUseCase internal constructor(

private fun doesUserSupportConversationProtocol(
conversationDetails: ConversationDetails,
selfUser: SelfUser
): Boolean {
val protocolInfo = conversationDetails.conversation.protocol
val acceptableProtocols = when (protocolInfo) {
is Conversation.ProtocolInfo.MLS -> setOf(SupportedProtocol.MLS)
// Messages in mixed conversations are sent through Proteus
is Conversation.ProtocolInfo.Mixed -> setOf(SupportedProtocol.PROTEUS)
Conversation.ProtocolInfo.Proteus -> setOf(SupportedProtocol.PROTEUS)
}
val isProtocolSupported = selfUser.supportedProtocols?.any { supported ->
acceptableProtocols.contains(supported)
} ?: false
return isProtocolSupported
isSelfClientMlsCapable: Boolean
): Boolean = when (conversationDetails.conversation.protocol) {
is Conversation.ProtocolInfo.MLS -> isSelfClientMlsCapable
// Messages in mixed conversations are sent through Proteus
is Conversation.ProtocolInfo.Mixed,
Conversation.ProtocolInfo.Proteus -> true
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,21 @@ package com.wire.kalium.logic.feature.conversation

import app.cash.turbine.test
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.InteractionAvailability
import com.wire.kalium.logic.data.user.ConnectionState
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestConversationDetails
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.right
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.test_util.testKaliumDispatcher
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangement
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangement
import com.wire.kalium.logic.util.arrangement.repository.ConversationRepositoryArrangementImpl
import com.wire.kalium.logic.util.arrangement.repository.UserRepositoryArrangement
Expand All @@ -41,6 +46,7 @@ import io.mockative.once
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.test.runTest
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
Expand All @@ -52,6 +58,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withSelfUserBeingMemberOfConversation(isMember = true)
}
Expand All @@ -76,6 +83,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val (arrangement, observeConversationInteractionAvailability) = arrange {
dispatcher = testKaliumDispatcher
withSelfUserBeingMemberOfConversation(isMember = false)
withIsClientMlsCapable(false.right())
}

observeConversationInteractionAvailability(conversationId).test {
Expand All @@ -96,6 +104,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withGroupConversationError()
}
Expand All @@ -118,6 +127,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withBlockedUserConversation()
}
Expand All @@ -132,14 +142,14 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {

awaitComplete()
}

}

@Test
fun givenOtherUserIsDeleted_whenInvokingInteractionForConversation_thenInteractionShouldBeDisabled() = runTest {
val conversationId = TestConversation.ID

val (arrangement, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withDeletedUserConversation()
}
Expand All @@ -156,11 +166,12 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
}
}

@Ignore // is this really a case that a client does not support Proteus
@Test
fun givenProteusConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest {
testProtocolSupport(
conversationProtocolInfo = Conversation.ProtocolInfo.Proteus,
userSupportedProtocols = setOf(SupportedProtocol.MLS),
isMlsCapable = true.right(),
expectedResult = InteractionAvailability.UNSUPPORTED_PROTOCOL
)
}
Expand All @@ -169,7 +180,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMLSConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MLS_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
isMlsCapable = false.right(),
expectedResult = InteractionAvailability.UNSUPPORTED_PROTOCOL
)
}
Expand All @@ -178,7 +189,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMixedConversationAndUserSupportsOnlyMLS_whenObserving_thenShouldReturnUnsupportedProtocol() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MIXED_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
isMlsCapable = false.right(),
expectedResult = InteractionAvailability.ENABLED
)
}
Expand All @@ -187,7 +198,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMixedConversationAndUserSupportsProteus_whenObserving_thenShouldReturnEnabled() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MIXED_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
isMlsCapable = false.right(),
expectedResult = InteractionAvailability.ENABLED
)
}
Expand All @@ -196,35 +207,35 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenMLSConversationAndUserSupportsMLS_whenObserving_thenShouldReturnEnabled() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.MLS_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.MLS),
expectedResult = InteractionAvailability.ENABLED
expectedResult = InteractionAvailability.ENABLED,
isMlsCapable = true.right()
)
}

@Test
fun givenProteusConversationAndUserSupportsProteus_whenObserving_thenShouldReturnEnabled() = runTest {
testProtocolSupport(
conversationProtocolInfo = TestConversation.PROTEUS_PROTOCOL_INFO,
userSupportedProtocols = setOf(SupportedProtocol.PROTEUS),
expectedResult = InteractionAvailability.ENABLED
expectedResult = InteractionAvailability.ENABLED,
isMlsCapable = false.right()
)
}

private suspend fun CoroutineScope.testProtocolSupport(
conversationProtocolInfo: Conversation.ProtocolInfo,
userSupportedProtocols: Set<SupportedProtocol>,
isMlsCapable: Either<StorageFailure, Boolean>,
expectedResult: InteractionAvailability
) {
val convId = TestConversationDetails.CONVERSATION_GROUP.conversation.id
val (_, observeConversationInteractionAvailabilityUseCase) = arrange {
withIsClientMlsCapable(isMlsCapable)
dispatcher = testKaliumDispatcher
val proteusGroupDetails = TestConversationDetails.CONVERSATION_GROUP.copy(
conversation = TestConversationDetails.CONVERSATION_GROUP.conversation.copy(
protocol = conversationProtocolInfo
)
)
withObserveConversationDetailsByIdReturning(Either.Right(proteusGroupDetails))
withObservingSelfUserReturning(flowOf(TestUser.SELF.copy(supportedProtocols = userSupportedProtocols)))
}

observeConversationInteractionAvailabilityUseCase(convId).test {
Expand All @@ -241,6 +252,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
val (_, observeConversationInteractionAvailability) = arrange {
dispatcher = testKaliumDispatcher
withLegalHoldOneOnOneConversation(Conversation.LegalHoldStatus.ENABLED)
withIsClientMlsCapable(false.right())
}
observeConversationInteractionAvailability(conversationId).test {
val interactionResult = awaitItem()
Expand All @@ -253,6 +265,7 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
fun givenConversationLegalHoldIsDegraded_whenInvokingInteractionForConversation_thenInteractionShouldBeLegalHold() = runTest {
val conversationId = TestConversation.ID
val (_, observeConversationInteractionAvailability) = arrange {
withIsClientMlsCapable(false.right())
dispatcher = testKaliumDispatcher
withLegalHoldOneOnOneConversation(Conversation.LegalHoldStatus.DEGRADED)
}
Expand All @@ -266,10 +279,12 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
private class Arrangement(
private val configure: suspend Arrangement.() -> Unit
) : UserRepositoryArrangement by UserRepositoryArrangementImpl(),
ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl() {
ConversationRepositoryArrangement by ConversationRepositoryArrangementImpl(),
CurrentClientIdProviderArrangement by CurrentClientIdProviderArrangementImpl() {

var dispatcher: KaliumDispatcher = TestKaliumDispatcher

val selfUser = UserId("self_value", "self_domain")
suspend fun withSelfUserBeingMemberOfConversation(isMember: Boolean) = apply {
withObserveConversationDetailsByIdReturning(
Either.Right(TestConversationDetails.CONVERSATION_GROUP.copy(isSelfUserMember = isMember))
Expand Down Expand Up @@ -315,17 +330,15 @@ class ObserveConversationInteractionAvailabilityUseCaseTest {
}

suspend fun arrange(): Pair<Arrangement, ObserveConversationInteractionAvailabilityUseCase> = run {
withObservingSelfUserReturning(
flowOf(
TestUser.SELF.copy(supportedProtocols = setOf(SupportedProtocol.MLS, SupportedProtocol.PROTEUS))
)
)
withCurrentClientIdSuccess(ClientId("client_id"))
configure()
this@Arrangement to ObserveConversationInteractionAvailabilityUseCase(
conversationRepository = conversationRepository,
userRepository = userRepository,
dispatcher = dispatcher
)
dispatcher = dispatcher,
selfUserId = selfUser,
selfClientIdProvider = currentClientIdProvider
)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.wire.kalium.logic.util.arrangement.repository

import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.mls.NameAndHandle
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.QualifiedID
Expand Down Expand Up @@ -95,6 +96,12 @@ internal interface UserRepositoryArrangement {
)

suspend fun withNameAndHandle(result: Either<StorageFailure, NameAndHandle>, userId: Matcher<UserId> = AnyMatcher(valueOf()))

suspend fun withIsClientMlsCapable(
result: Either<StorageFailure, Boolean>,
userId: Matcher<UserId> = AnyMatcher(valueOf()),
clientId: Matcher<ClientId> = AnyMatcher(valueOf())
)
}

@Suppress("INAPPLICABLE_JVM_NAME")
Expand Down Expand Up @@ -233,4 +240,13 @@ internal open class UserRepositoryArrangementImpl : UserRepositoryArrangement {
override suspend fun withNameAndHandle(result: Either<StorageFailure, NameAndHandle>, userId: Matcher<UserId>) {
coEvery { userRepository.getNameAndHandle(matches { userId.matches(it) }) }.returns(result)
}

override suspend fun withIsClientMlsCapable(result: Either<StorageFailure, Boolean>, userId: Matcher<UserId>, clientId: Matcher<ClientId>) {
coEvery {
userRepository.isClientMlsCapable(
userId = matches { userId.matches(it) },
clientId = matches { clientId.matches(it) }
)
}.returns(result)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ SELECT * FROM Client WHERE user_id = :user_id AND id = :client_id;
deleteClientsOfUserExcept:
DELETE FROM Client WHERE user_id = :user_id AND id NOT IN :exception_ids;

isClientMLSCapable:
SELECT is_mls_capable FROM Client WHERE user_id = :user_id AND id = :client_id;

tryMarkAsInvalid:
UPDATE OR IGNORE Client SET is_valid = 0 WHERE user_id = :user_id AND id IN :clientId_List;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,5 @@ interface ClientDAO {
): Map<QualifiedIDEntity, List<Client>>

suspend fun selectAllClients(): Map<QualifiedIDEntity, List<Client>>
suspend fun isMLSCapable(userId: QualifiedIDEntity, clientId: String): Boolean?
}
Loading

0 comments on commit 2818d42

Please sign in to comment.