diff --git a/courier-core/api/courier-core.api b/courier-core/api/courier-core.api index 11bf9339..21f4bfd6 100644 --- a/courier-core/api/courier-core.api +++ b/courier-core/api/courier-core.api @@ -72,6 +72,9 @@ public final class com/gojek/courier/extensions/CollectionExtensionsKt { public static final fun toImmutableSet (Ljava/util/Set;)Ljava/util/Set; } +public final class com/gojek/courier/extensions/StringExtensionsKt { +} + public final class com/gojek/courier/extensions/TimeUnitExtensionsKt { } diff --git a/courier-core/src/main/java/com/gojek/courier/extensions/StringExtensions.kt b/courier-core/src/main/java/com/gojek/courier/extensions/StringExtensions.kt new file mode 100644 index 00000000..64d366d9 --- /dev/null +++ b/courier-core/src/main/java/com/gojek/courier/extensions/StringExtensions.kt @@ -0,0 +1,9 @@ +package com.gojek.courier.extensions + +import androidx.annotation.RestrictTo + +@RestrictTo(RestrictTo.Scope.LIBRARY) +fun String.isWildCardTopic(): Boolean { + return startsWith("+/") || contains("/+/") || endsWith("/+") || equals("+") || + endsWith("/#") || equals("#") +} diff --git a/courier/src/main/java/com/gojek/courier/coordinator/Coordinator.kt b/courier/src/main/java/com/gojek/courier/coordinator/Coordinator.kt index 38e75f82..ba3a15c8 100644 --- a/courier/src/main/java/com/gojek/courier/coordinator/Coordinator.kt +++ b/courier/src/main/java/com/gojek/courier/coordinator/Coordinator.kt @@ -16,10 +16,14 @@ import com.gojek.mqtt.client.model.ConnectionState import com.gojek.mqtt.client.model.MqttMessage import com.gojek.mqtt.event.EventHandler import com.gojek.mqtt.event.MqttEvent +import com.gojek.mqtt.event.MqttEvent.MqttSubscribeFailureEvent import io.reactivex.BackpressureStrategy import io.reactivex.Flowable import io.reactivex.FlowableOnSubscribe +import io.reactivex.disposables.CompositeDisposable import io.reactivex.schedulers.Schedulers +import io.reactivex.subjects.PublishSubject +import org.eclipse.paho.client.mqttv3.MqttException import org.reactivestreams.Subscriber import org.reactivestreams.Subscription @@ -28,6 +32,18 @@ internal class Coordinator( private val logger: ILogger ) : StubInterface.Callback { + private val eventSubject = PublishSubject.create { emitter -> + val eventHandler = object : EventHandler { + override fun onEvent(mqttEvent: MqttEvent) { + if (emitter.isDisposed.not()) { + emitter.onNext(mqttEvent) + } + } + } + client.addEventHandler(eventHandler) + emitter.setCancellable { client.removeEventHandler(eventHandler) } + } + @Synchronized override fun send(stubMethod: StubMethod.Send, args: Array): Any { logger.d("Coordinator", "Send method invoked") @@ -106,7 +122,15 @@ internal class Coordinator( } } client.addMessageListener(topic, listener) - emitter.setCancellable { client.removeMessageListener(topic, listener) } + val eventDisposable = eventSubject.filter { event -> + isInvalidSubscriptionFailureEvent(event, topic) + }.subscribe { + client.removeMessageListener(topic, listener) + } + emitter.setCancellable { + client.removeMessageListener(topic, listener) + eventDisposable.dispose() + } }, BackpressureStrategy.BUFFER ) @@ -166,9 +190,22 @@ internal class Coordinator( } } } + val eventDisposable = CompositeDisposable() for (topic in topicList) { client.addMessageListener(topic.first, listener) - emitter.setCancellable { client.removeMessageListener(topic.first, listener) } + eventDisposable.add( + eventSubject.filter { event -> + isInvalidSubscriptionFailureEvent(event, topic.first) + }.subscribe { + client.removeMessageListener(topic.first, listener) + } + ) + } + emitter.setCancellable { + for (topic in topicList) { + client.removeMessageListener(topic.first, listener) + eventDisposable.dispose() + } } }, BackpressureStrategy.BUFFER @@ -249,4 +286,10 @@ internal class Coordinator( null } } + + private fun isInvalidSubscriptionFailureEvent(event: MqttEvent, topic: String): Boolean { + return event is MqttSubscribeFailureEvent && + event.topics.containsKey(topic) && + event.exception.reasonCode == MqttException.REASON_CODE_INVALID_SUBSCRIPTION.toInt() + } } diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/client/IncomingMsgControllerImpl.kt b/mqtt-client/src/main/java/com/gojek/mqtt/client/IncomingMsgControllerImpl.kt index d36a1688..f8782c99 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/client/IncomingMsgControllerImpl.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/client/IncomingMsgControllerImpl.kt @@ -1,6 +1,7 @@ package com.gojek.mqtt.client import com.gojek.courier.extensions.fromSecondsToNanos +import com.gojek.courier.extensions.isWildCardTopic import com.gojek.courier.logging.ILogger import com.gojek.courier.utils.Clock import com.gojek.mqtt.client.listener.MessageListener @@ -47,6 +48,8 @@ internal class IncomingMsgControllerImpl( private val listenerMap = ConcurrentHashMap>() + private val wildcardTopicListenerMap = ConcurrentHashMap>() + private var cleanupFuture: ScheduledFuture<*>? = null override fun triggerHandleMessage() { @@ -64,40 +67,68 @@ internal class IncomingMsgControllerImpl( @Synchronized override fun registerListener(topic: String, listener: MessageListener) { - listenerMap[topic] = (listenerMap[topic] ?: emptyList()) + listener + if (topic.isWildCardTopic()) { + wildcardTopicListenerMap[topic] = (wildcardTopicListenerMap[topic] ?: emptyList()) + listener + } else { + listenerMap[topic] = (listenerMap[topic] ?: emptyList()) + listener + } triggerHandleMessage() } @Synchronized override fun unregisterListener(topic: String, listener: MessageListener) { - listenerMap[topic] = (listenerMap[topic] ?: emptyList()) - listener - if (listenerMap[topic]!!.isEmpty()) { - listenerMap.remove(topic) + if (topic.isWildCardTopic()) { + wildcardTopicListenerMap[topic] = (wildcardTopicListenerMap[topic] ?: emptyList()) - listener + if (wildcardTopicListenerMap[topic]!!.isEmpty()) { + wildcardTopicListenerMap.remove(topic) + } + } else { + listenerMap[topic] = (listenerMap[topic] ?: emptyList()) - listener + if (listenerMap[topic]!!.isEmpty()) { + listenerMap.remove(topic) + } } } private inner class HandleMessage : Runnable { override fun run() { try { - if (listenerMap.keys.isEmpty()) { + if (listenerMap.keys.isEmpty() && wildcardTopicListenerMap.isEmpty()) { logger.d(TAG, "No listeners registered") return } val messages: List = mqttReceivePersistence.getAllIncomingMessagesWithTopicFilter(listenerMap.keys) - if (mqttUtils.isEmpty(messages)) { - logger.d(TAG, "No Messages in Table") - return - } val deletedMsgIds = mutableListOf() for (message in messages) { logger.d(TAG, "Going to process ${message.messageId}") - val listenersNotified = notifyListeners(message) + val listenersNotified = notifyListeners(message, listenerMap[message.topic]!!) if (listenersNotified) { deletedMsgIds.add(message.messageId) } logger.d(TAG, "Successfully Processed Message ${message.messageId}") } + // processing messages for wildcard topic subscription + for (wildCardTopic in wildcardTopicListenerMap.keys()) { + val topicForDBQuery = parseWildCardTopicForDBQuery(wildCardTopic) + val wildcardMessages: List = + mqttReceivePersistence.getAllIncomingMessagesForWildCardTopic(topicForDBQuery) + for (message in wildcardMessages) { + logger.d(TAG, "Going to process ${message.messageId}") + val wildCardTopicRegex = parseWildCardTopicForRegex(wildCardTopic) + if (wildCardTopicRegex.matches(message.topic)) { + logger.d(TAG, "Wildcard topic: $wildCardTopic matches ${message.topic}") + val listenersNotified = + notifyListeners(message, wildcardTopicListenerMap[wildCardTopic]!!) + if (listenersNotified) { + deletedMsgIds.add(message.messageId) + } + } else { + logger.d(TAG, "Wildcard topic: $wildCardTopic does not match ${message.topic}") + } + logger.d(TAG, "Successfully Processed Message ${message.messageId}") + } + } if (deletedMsgIds.isNotEmpty()) { val deletedMessagesCount = deleteMessages(deletedMsgIds) logger.d(TAG, "Deleted $deletedMessagesCount messages") @@ -112,6 +143,18 @@ internal class IncomingMsgControllerImpl( } } + private fun parseWildCardTopicForDBQuery(topic: String): String { + var updatedTopic: String = topic.replace("+", "%") + updatedTopic = updatedTopic.replace("#", "%") + return updatedTopic + } + + private fun parseWildCardTopicForRegex(topic: String): Regex { + var updatedTopic: String = topic.replace("+", "[^\\/]+") + updatedTopic = updatedTopic.replace("#", "([^\\/]+(\\/?[^\\/])*)+") + return Regex(updatedTopic) + } + private inner class CleanupExpiredMessages : Runnable { override fun run() { logger.d(TAG, "Deleting expired messages") @@ -123,10 +166,10 @@ internal class IncomingMsgControllerImpl( } } - private fun notifyListeners(message: MqttReceivePacket): Boolean { + private fun notifyListeners(message: MqttReceivePacket, listeners: List): Boolean { var notified = false try { - listenerMap[message.topic]!!.forEach { + listeners.forEach { notified = true it.onMessageReceived(message.toMqttMessage()) } diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/connection/MqttConnection.kt b/mqtt-client/src/main/java/com/gojek/mqtt/connection/MqttConnection.kt index c699db29..8d040836 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/connection/MqttConnection.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/connection/MqttConnection.kt @@ -515,6 +515,7 @@ internal class MqttConnection( ), timeTakenMillis = (clock.nanoTime() - subscribeStartTime).fromNanosToMillis() ) + subscriptionStore.getListener().onInvalidTopicsSubscribeFailure(topicMap) } } } @@ -546,6 +547,7 @@ internal class MqttConnection( ), timeTakenMillis = (clock.nanoTime() - unsubscribeStartTime).fromNanosToMillis() ) + subscriptionStore.getListener().onInvalidTopicsUnsubscribeFailure(topics) } } } @@ -576,11 +578,12 @@ internal class MqttConnection( connectionConfig.connectionEventHandler.onMqttSubscribeFailure( topics = failTopicMap, timeTakenMillis = (clock.nanoTime() - context.startTime).fromNanosToMillis(), - throwable = MqttException(MqttException.REASON_CODE_INVALID_SUBSCRIPTION.toInt()) + throwable = MqttException(REASON_CODE_INVALID_SUBSCRIPTION.toInt()) ) } subscriptionStore.getListener().onTopicsSubscribed(successTopicMap) + subscriptionStore.getListener().onInvalidTopicsSubscribeFailure(failTopicMap) subscriptionPolicy.resetParams() } diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/persistence/IMqttReceivePersistence.kt b/mqtt-client/src/main/java/com/gojek/mqtt/persistence/IMqttReceivePersistence.kt index e7dcb280..c0a8a605 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/persistence/IMqttReceivePersistence.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/persistence/IMqttReceivePersistence.kt @@ -7,4 +7,5 @@ internal interface IMqttReceivePersistence { fun getAllIncomingMessagesWithTopicFilter(topics: Set): List fun removeReceivedMessages(messageIds: List): Int fun removeMessagesWithOlderTimestamp(timestampNanos: Long): Int + fun getAllIncomingMessagesForWildCardTopic(topic: String): List } diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/persistence/dao/IncomingMessagesDao.kt b/mqtt-client/src/main/java/com/gojek/mqtt/persistence/dao/IncomingMessagesDao.kt index 5aedd4cb..8f8783a9 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/persistence/dao/IncomingMessagesDao.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/persistence/dao/IncomingMessagesDao.kt @@ -13,6 +13,9 @@ internal interface IncomingMessagesDao { @Query("SELECT * from incoming_messages where topic in (:topics)") fun getAllMessagesWithTopicFilter(topics: Set): List + @Query("SELECT * from incoming_messages where topic LIKE :topic") + fun getAllIncomingMessagesForWildCardTopic(topic: String): List + @Query("DELETE from incoming_messages") fun clearAllMessages() diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/persistence/impl/PahoPersistence.kt b/mqtt-client/src/main/java/com/gojek/mqtt/persistence/impl/PahoPersistence.kt index caf29d31..c0492a66 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/persistence/impl/PahoPersistence.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/persistence/impl/PahoPersistence.kt @@ -86,6 +86,10 @@ internal class PahoPersistence(private val context: Context) : return incomingMessagesDao.getAllMessagesWithTopicFilter(topics) } + override fun getAllIncomingMessagesForWildCardTopic(topic: String): List { + return incomingMessagesDao.getAllIncomingMessagesForWildCardTopic(topic) + } + override fun removeReceivedMessages(messageIds: List): Int { return incomingMessagesDao.removeMessagesById(messageIds) } diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/InMemorySubscriptionStore.kt b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/InMemorySubscriptionStore.kt index f948d381..3d88cf94 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/InMemorySubscriptionStore.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/InMemorySubscriptionStore.kt @@ -4,7 +4,13 @@ import com.gojek.courier.QoS internal class InMemorySubscriptionStore : SubscriptionStore { private var state = State(mapOf()) - private val listener = object : SubscriptionStoreListener {} + private val listener = object : SubscriptionStoreListener { + override fun onInvalidTopicsSubscribeFailure(topicMap: Map) { + state = state.copy( + subscriptionTopics = state.subscriptionTopics - topicMap.keys + ) + } + } private data class State(val subscriptionTopics: Map) diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStore.kt b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStore.kt index 6afaffe5..ebb011f9 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStore.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStore.kt @@ -15,6 +15,17 @@ internal class PersistableSubscriptionStore(context: Context) : SubscriptionStor override fun onTopicsUnsubscribed(topics: Set) { onTopicsUnsubscribedInternal(topics) } + + override fun onInvalidTopicsSubscribeFailure(topicMap: Map) { + state = state.copy( + subscriptionTopics = state.subscriptionTopics - topicMap.keys, + pendingUnsubscribeTopics = state.pendingUnsubscribeTopics + ) + } + + override fun onInvalidTopicsUnsubscribeFailure(topics: Set) { + onTopicsUnsubscribedInternal(topics) + } } private data class State( diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStoreV2.kt b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStoreV2.kt index f26368e5..23c40372 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStoreV2.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/PersistableSubscriptionStoreV2.kt @@ -13,6 +13,17 @@ internal class PersistableSubscriptionStoreV2(context: Context) : SubscriptionSt override fun onTopicsUnsubscribed(topics: Set) { onTopicsUnsubscribedInternal(topics) } + + override fun onInvalidTopicsSubscribeFailure(topicMap: Map) { + state = state.copy( + subscriptionTopics = state.subscriptionTopics - topicMap.keys, + pendingUnsubscribeTopics = state.pendingUnsubscribeTopics + ) + } + + override fun onInvalidTopicsUnsubscribeFailure(topics: Set) { + onTopicsUnsubscribedInternal(topics) + } } private data class State( diff --git a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/SubscriptionStore.kt b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/SubscriptionStore.kt index e8da0fcb..9378dd5a 100644 --- a/mqtt-client/src/main/java/com/gojek/mqtt/subscription/SubscriptionStore.kt +++ b/mqtt-client/src/main/java/com/gojek/mqtt/subscription/SubscriptionStore.kt @@ -13,5 +13,7 @@ internal interface SubscriptionStore { internal interface SubscriptionStoreListener { fun onTopicsSubscribed(topicMap: Map) = Unit + fun onInvalidTopicsSubscribeFailure(topicMap: Map) = Unit fun onTopicsUnsubscribed(topics: Set) = Unit + fun onInvalidTopicsUnsubscribeFailure(topics: Set) = Unit }