Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ISSUE#46] Add support for wildcard subscription #75

Merged
merged 3 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions courier-core/api/courier-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public final class com/gojek/courier/extensions/CollectionExtensionsKt {
public static final fun toImmutableSet (Ljava/util/Set;)Ljava/util/Set;
}

public final class com/gojek/courier/extensions/StringExtensionsKt {
}

public final class com/gojek/courier/extensions/TimeUnitExtensionsKt {
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.gojek.courier.extensions

import androidx.annotation.RestrictTo

@RestrictTo(RestrictTo.Scope.LIBRARY)
fun String.isWildCardTopic(): Boolean {
return startsWith("+/") || contains("/+/") || endsWith("/+") || equals("+") ||
endsWith("/#") || equals("#")
}
47 changes: 45 additions & 2 deletions courier/src/main/java/com/gojek/courier/coordinator/Coordinator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ import com.gojek.mqtt.client.model.ConnectionState
import com.gojek.mqtt.client.model.MqttMessage
import com.gojek.mqtt.event.EventHandler
import com.gojek.mqtt.event.MqttEvent
import com.gojek.mqtt.event.MqttEvent.MqttSubscribeFailureEvent
import io.reactivex.BackpressureStrategy
import io.reactivex.Flowable
import io.reactivex.FlowableOnSubscribe
import io.reactivex.disposables.CompositeDisposable
import io.reactivex.schedulers.Schedulers
import io.reactivex.subjects.PublishSubject
import org.eclipse.paho.client.mqttv3.MqttException
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription

Expand All @@ -28,6 +32,18 @@ internal class Coordinator(
private val logger: ILogger
) : StubInterface.Callback {

private val eventSubject = PublishSubject.create<MqttEvent> { emitter ->
val eventHandler = object : EventHandler {
override fun onEvent(mqttEvent: MqttEvent) {
if (emitter.isDisposed.not()) {
emitter.onNext(mqttEvent)
}
}
}
client.addEventHandler(eventHandler)
emitter.setCancellable { client.removeEventHandler(eventHandler) }
}

@Synchronized
override fun send(stubMethod: StubMethod.Send, args: Array<Any>): Any {
logger.d("Coordinator", "Send method invoked")
Expand Down Expand Up @@ -106,7 +122,15 @@ internal class Coordinator(
}
}
client.addMessageListener(topic, listener)
emitter.setCancellable { client.removeMessageListener(topic, listener) }
val eventDisposable = eventSubject.filter { event ->
isInvalidSubscriptionFailureEvent(event, topic)
}.subscribe {
client.removeMessageListener(topic, listener)
}
emitter.setCancellable {
client.removeMessageListener(topic, listener)
eventDisposable.dispose()
}
},
BackpressureStrategy.BUFFER
)
Expand Down Expand Up @@ -166,9 +190,22 @@ internal class Coordinator(
}
}
}
val eventDisposable = CompositeDisposable()
for (topic in topicList) {
client.addMessageListener(topic.first, listener)
emitter.setCancellable { client.removeMessageListener(topic.first, listener) }
eventDisposable.add(
eventSubject.filter { event ->
isInvalidSubscriptionFailureEvent(event, topic.first)
}.subscribe {
client.removeMessageListener(topic.first, listener)
}
)
}
emitter.setCancellable {
for (topic in topicList) {
client.removeMessageListener(topic.first, listener)
eventDisposable.dispose()
}
}
},
BackpressureStrategy.BUFFER
Expand Down Expand Up @@ -249,4 +286,10 @@ internal class Coordinator(
null
}
}

private fun isInvalidSubscriptionFailureEvent(event: MqttEvent, topic: String): Boolean {
return event is MqttSubscribeFailureEvent &&
event.topics.containsKey(topic) &&
event.exception.reasonCode == MqttException.REASON_CODE_INVALID_SUBSCRIPTION.toInt()
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.gojek.mqtt.client

import com.gojek.courier.extensions.fromSecondsToNanos
import com.gojek.courier.extensions.isWildCardTopic
import com.gojek.courier.logging.ILogger
import com.gojek.courier.utils.Clock
import com.gojek.mqtt.client.listener.MessageListener
Expand Down Expand Up @@ -47,6 +48,8 @@ internal class IncomingMsgControllerImpl(

private val listenerMap = ConcurrentHashMap<String, List<MessageListener>>()

private val wildcardTopicListenerMap = ConcurrentHashMap<String, List<MessageListener>>()

private var cleanupFuture: ScheduledFuture<*>? = null

override fun triggerHandleMessage() {
Expand All @@ -64,40 +67,68 @@ internal class IncomingMsgControllerImpl(

@Synchronized
override fun registerListener(topic: String, listener: MessageListener) {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) + listener
if (topic.isWildCardTopic()) {
wildcardTopicListenerMap[topic] = (wildcardTopicListenerMap[topic] ?: emptyList()) + listener
} else {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) + listener
}
triggerHandleMessage()
}

@Synchronized
override fun unregisterListener(topic: String, listener: MessageListener) {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) - listener
if (listenerMap[topic]!!.isEmpty()) {
listenerMap.remove(topic)
if (topic.isWildCardTopic()) {
wildcardTopicListenerMap[topic] = (wildcardTopicListenerMap[topic] ?: emptyList()) - listener
if (wildcardTopicListenerMap[topic]!!.isEmpty()) {
wildcardTopicListenerMap.remove(topic)
}
} else {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) - listener
if (listenerMap[topic]!!.isEmpty()) {
listenerMap.remove(topic)
}
}
}

private inner class HandleMessage : Runnable {
override fun run() {
try {
if (listenerMap.keys.isEmpty()) {
if (listenerMap.keys.isEmpty() && wildcardTopicListenerMap.isEmpty()) {
logger.d(TAG, "No listeners registered")
return
}
val messages: List<MqttReceivePacket> =
mqttReceivePersistence.getAllIncomingMessagesWithTopicFilter(listenerMap.keys)
if (mqttUtils.isEmpty(messages)) {
logger.d(TAG, "No Messages in Table")
return
}
val deletedMsgIds = mutableListOf<Long>()
for (message in messages) {
logger.d(TAG, "Going to process ${message.messageId}")
val listenersNotified = notifyListeners(message)
val listenersNotified = notifyListeners(message, listenerMap[message.topic]!!)
if (listenersNotified) {
deletedMsgIds.add(message.messageId)
}
logger.d(TAG, "Successfully Processed Message ${message.messageId}")
}
// processing messages for wildcard topic subscription
for (wildCardTopic in wildcardTopicListenerMap.keys()) {
val topicForDBQuery = parseWildCardTopicForDBQuery(wildCardTopic)
val wildcardMessages: List<MqttReceivePacket> =
mqttReceivePersistence.getAllIncomingMessagesForWildCardTopic(topicForDBQuery)
for (message in wildcardMessages) {
logger.d(TAG, "Going to process ${message.messageId}")
val wildCardTopicRegex = parseWildCardTopicForRegex(wildCardTopic)
if (wildCardTopicRegex.matches(message.topic)) {
logger.d(TAG, "Wildcard topic: $wildCardTopic matches ${message.topic}")
val listenersNotified =
notifyListeners(message, wildcardTopicListenerMap[wildCardTopic]!!)
if (listenersNotified) {
deletedMsgIds.add(message.messageId)
}
} else {
logger.d(TAG, "Wildcard topic: $wildCardTopic does not match ${message.topic}")
}
logger.d(TAG, "Successfully Processed Message ${message.messageId}")
}
}
if (deletedMsgIds.isNotEmpty()) {
val deletedMessagesCount = deleteMessages(deletedMsgIds)
logger.d(TAG, "Deleted $deletedMessagesCount messages")
Expand All @@ -112,6 +143,18 @@ internal class IncomingMsgControllerImpl(
}
}

private fun parseWildCardTopicForDBQuery(topic: String): String {
var updatedTopic: String = topic.replace("+", "%")
updatedTopic = updatedTopic.replace("#", "%")
return updatedTopic
}

private fun parseWildCardTopicForRegex(topic: String): Regex {
var updatedTopic: String = topic.replace("+", "[^\\/]+")
updatedTopic = updatedTopic.replace("#", "([^\\/]+(\\/?[^\\/])*)+")
return Regex(updatedTopic)
}

private inner class CleanupExpiredMessages : Runnable {
override fun run() {
logger.d(TAG, "Deleting expired messages")
Expand All @@ -123,10 +166,10 @@ internal class IncomingMsgControllerImpl(
}
}

private fun notifyListeners(message: MqttReceivePacket): Boolean {
private fun notifyListeners(message: MqttReceivePacket, listeners: List<MessageListener>): Boolean {
var notified = false
try {
listenerMap[message.topic]!!.forEach {
listeners.forEach {
notified = true
it.onMessageReceived(message.toMqttMessage())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ internal class MqttConnection(
),
timeTakenMillis = (clock.nanoTime() - subscribeStartTime).fromNanosToMillis()
)
subscriptionStore.getListener().onInvalidTopicsSubscribeFailure(topicMap)
}
}
}
Expand Down Expand Up @@ -546,6 +547,7 @@ internal class MqttConnection(
),
timeTakenMillis = (clock.nanoTime() - unsubscribeStartTime).fromNanosToMillis()
)
subscriptionStore.getListener().onInvalidTopicsUnsubscribeFailure(topics)
}
}
}
Expand Down Expand Up @@ -576,11 +578,12 @@ internal class MqttConnection(
connectionConfig.connectionEventHandler.onMqttSubscribeFailure(
topics = failTopicMap,
timeTakenMillis = (clock.nanoTime() - context.startTime).fromNanosToMillis(),
throwable = MqttException(MqttException.REASON_CODE_INVALID_SUBSCRIPTION.toInt())
throwable = MqttException(REASON_CODE_INVALID_SUBSCRIPTION.toInt())
)
}

subscriptionStore.getListener().onTopicsSubscribed(successTopicMap)
subscriptionStore.getListener().onInvalidTopicsSubscribeFailure(failTopicMap)
subscriptionPolicy.resetParams()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ internal interface IMqttReceivePersistence {
fun getAllIncomingMessagesWithTopicFilter(topics: Set<String>): List<MqttReceivePacket>
fun removeReceivedMessages(messageIds: List<Long>): Int
fun removeMessagesWithOlderTimestamp(timestampNanos: Long): Int
fun getAllIncomingMessagesForWildCardTopic(topic: String): List<MqttReceivePacket>
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ internal interface IncomingMessagesDao {
@Query("SELECT * from incoming_messages where topic in (:topics)")
fun getAllMessagesWithTopicFilter(topics: Set<String>): List<MqttReceivePacket>

@Query("SELECT * from incoming_messages where topic LIKE :topic")
fun getAllIncomingMessagesForWildCardTopic(topic: String): List<MqttReceivePacket>

@Query("DELETE from incoming_messages")
fun clearAllMessages()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ internal class PahoPersistence(private val context: Context) :
return incomingMessagesDao.getAllMessagesWithTopicFilter(topics)
}

override fun getAllIncomingMessagesForWildCardTopic(topic: String): List<MqttReceivePacket> {
return incomingMessagesDao.getAllIncomingMessagesForWildCardTopic(topic)
}

override fun removeReceivedMessages(messageIds: List<Long>): Int {
return incomingMessagesDao.removeMessagesById(messageIds)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@ import com.gojek.courier.QoS

internal class InMemorySubscriptionStore : SubscriptionStore {
private var state = State(mapOf())
private val listener = object : SubscriptionStoreListener {}
private val listener = object : SubscriptionStoreListener {
override fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) {
state = state.copy(
subscriptionTopics = state.subscriptionTopics - topicMap.keys
)
}
}

private data class State(val subscriptionTopics: Map<String, QoS>)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ internal class PersistableSubscriptionStore(context: Context) : SubscriptionStor
override fun onTopicsUnsubscribed(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}

override fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) {
state = state.copy(
subscriptionTopics = state.subscriptionTopics - topicMap.keys,
pendingUnsubscribeTopics = state.pendingUnsubscribeTopics
)
}

override fun onInvalidTopicsUnsubscribeFailure(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}
}

private data class State(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ internal class PersistableSubscriptionStoreV2(context: Context) : SubscriptionSt
override fun onTopicsUnsubscribed(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}

override fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) {
state = state.copy(
subscriptionTopics = state.subscriptionTopics - topicMap.keys,
pendingUnsubscribeTopics = state.pendingUnsubscribeTopics
)
}

override fun onInvalidTopicsUnsubscribeFailure(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}
}

private data class State(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ internal interface SubscriptionStore {

internal interface SubscriptionStoreListener {
fun onTopicsSubscribed(topicMap: Map<String, QoS>) = Unit
fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) = Unit
fun onTopicsUnsubscribed(topics: Set<String>) = Unit
fun onInvalidTopicsUnsubscribeFailure(topics: Set<String>) = Unit
}