Skip to content

Commit

Permalink
ARTEMIS-5184 STOMP noLocal is scoped to session not subscription
Browse files Browse the repository at this point in the history
This closes #5414
  • Loading branch information
jbertram authored and gemmellr committed Jan 9, 2025
1 parent 70c84ad commit 09b8f67
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ protected void sendServerMessage(ICoreMessage message, String txID) throws Activ
try {
StompSession stompSession = getSession(txID);

if (stompSession.isNoLocal()) {
// only set the connection ID property if we have a noLocal subscription
if (stompSession.getNoLocalSubscriptionCount() > 0) {
message.putStringProperty(CONNECTION_ID_PROPERTY_NAME_STRING, getID().toString());
}
if (isEnableMessageID()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -362,13 +363,12 @@ public StompPostReceiptFunction subscribe(StompConnection connection,
boolean noLocal,
Integer consumerWindowSize) throws Exception {
StompSession stompSession = getSession(connection);
stompSession.setNoLocal(noLocal);
if (stompSession.containsSubscription(subscriptionID)) {
throw new ActiveMQStompException(connection, "There already is a subscription for: " + subscriptionID +
". Either use unique subscription IDs or do not create multiple subscriptions for the same destination");
}
long consumerID = server.getStorageManager().generateID();
return stompSession.addSubscription(consumerID, subscriptionID, connection.getClientID(), durableSubscriptionName, destination, selector, ack, consumerWindowSize);
return stompSession.addSubscription(consumerID, subscriptionID, connection.getClientID(), durableSubscriptionName, destination, selector, ack, noLocal, consumerWindowSize);
}

public void unsubscribe(StompConnection connection,
Expand Down Expand Up @@ -407,4 +407,8 @@ public boolean destinationExists(String destination) {
public ActiveMQServer getServer() {
return server;
}

public Collection<StompSession> getSessions() {
return sessions.values();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicInteger;

import io.netty.channel.EventLoop;
import org.apache.activemq.artemis.api.core.ActiveMQBuffer;
Expand Down Expand Up @@ -82,7 +83,7 @@ public class StompSession implements SessionCallback {
// key = consumer ID and message ID, value = frame length
private final Map<Pair<Long, Long>, Integer> messagesToAck = new ConcurrentHashMap<>();

private volatile boolean noLocal = false;
private AtomicInteger noLocalSubscriptionCount = new AtomicInteger(0);

private boolean txPending = false;

Expand Down Expand Up @@ -231,6 +232,9 @@ public void closed() {
public void disconnect(ServerConsumer consumerId, String errorDescription) {
StompSubscription stompSubscription = subscriptions.remove(consumerId.getID());
if (stompSubscription != null) {
if (stompSubscription.isNoLocal()) {
noLocalSubscriptionCount.decrementAndGet();
}
StompFrame frame = connection.getFrameHandler().createStompFrame(Stomp.Responses.ERROR);
frame.addHeader(Stomp.Headers.CONTENT_TYPE, "text/plain");
frame.setBody("consumer with ID " + consumerId + " disconnected by server");
Expand Down Expand Up @@ -306,6 +310,7 @@ public StompPostReceiptFunction addSubscription(long consumerID,
String destination,
String selector,
String ack,
boolean noLocal,
Integer consumerWindowSize) throws Exception {
SimpleString address = SimpleString.of(destination);
SimpleString queueName = SimpleString.of(destination);
Expand Down Expand Up @@ -342,8 +347,11 @@ public StompPostReceiptFunction addSubscription(long consumerID,
session.createQueue(QueueConfiguration.of(queueName).setAddress(address).setFilterString(selectorSimple).setDurable(false).setTemporary(true));
}
}
if (noLocal) {
noLocalSubscriptionCount.incrementAndGet();
}
final ServerConsumer consumer = session.createConsumer(consumerID, queueName, multicast ? null : selectorSimple, false, false, 0);
StompSubscription subscription = new StompSubscription(subscriptionID, ack, queueName, multicast, finalConsumerWindowSize);
StompSubscription subscription = new StompSubscription(subscriptionID, ack, queueName, multicast, noLocal, finalConsumerWindowSize);
subscriptions.put(consumerID, subscription);
session.start();
/*
Expand All @@ -363,6 +371,9 @@ public boolean unsubscribe(String id, String durableSubscriptionName, String cli
StompSubscription sub = entry.getValue();
if (id != null && id.equals(sub.getID())) {
iterator.remove();
if (sub.isNoLocal()) {
noLocalSubscriptionCount.decrementAndGet();
}
SimpleString queueName = sub.getQueueName();
session.closeConsumer(consumerID);
Queue queue = manager.getServer().locateQueue(queueName);
Expand Down Expand Up @@ -402,12 +413,12 @@ public OperationContext getContext() {
return sessionContext;
}

public boolean isNoLocal() {
return noLocal;
public int getNoLocalSubscriptionCount() {
return noLocalSubscriptionCount.get();
}

public void setNoLocal(boolean noLocal) {
this.noLocal = noLocal;
public int getSubscriptionCount() {
return subscriptions.size();
}

public void sendInternal(Message message, boolean direct) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ public class StompSubscription {
// whether or not this subscription follows multicast semantics (e.g. for a JMS topic)
private final boolean multicast;

private final boolean noLocal;

private final int consumerWindowSize;

public StompSubscription(String subID, String ack, SimpleString queueName, boolean multicast, int consumerWindowSize) {
public StompSubscription(String subID, String ack, SimpleString queueName, boolean multicast, boolean noLocal, int consumerWindowSize) {
this.subID = subID;
this.ack = ack;
this.queueName = queueName;
this.multicast = multicast;
this.noLocal = noLocal;
this.consumerWindowSize = consumerWindowSize;
}

Expand All @@ -55,13 +58,16 @@ public boolean isMulticast() {
return multicast;
}

public boolean isNoLocal() {
return noLocal;
}

public int getConsumerWindowSize() {
return consumerWindowSize;
}

@Override
public String toString() {
return "StompSubscription[id=" + subID + ", ack=" + ack + ", queueName=" + queueName + ", multicast=" + multicast + ", consumerWindowSize=" + consumerWindowSize + "]";
return "StompSubscription[id=" + subID + ", ack=" + ack + ", queueName=" + queueName + ", multicast=" + multicast + ", noLocal=" + noLocal + ", consumerWindowSize=" + consumerWindowSize + "]";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.apache.activemq.artemis.core.protocol.stomp.Stomp;
import org.apache.activemq.artemis.core.protocol.stomp.StompProtocolManager;
import org.apache.activemq.artemis.core.protocol.stomp.StompProtocolManagerFactory;
import org.apache.activemq.artemis.core.protocol.stomp.StompSession;
import org.apache.activemq.artemis.core.server.ActiveMQServer;
import org.apache.activemq.artemis.core.server.Queue;
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl;
Expand All @@ -65,7 +66,6 @@
import org.apache.activemq.artemis.json.JsonObject;
import org.apache.activemq.artemis.logs.AssertionLoggerHandler;
import org.apache.activemq.artemis.reader.MessageUtil;
import org.apache.activemq.artemis.spi.core.remoting.Acceptor;
import org.apache.activemq.artemis.tests.integration.mqtt.FuseMQTTClientProvider;
import org.apache.activemq.artemis.tests.integration.mqtt.MQTTClientProvider;
import org.apache.activemq.artemis.tests.integration.stomp.util.ClientStompFrame;
Expand Down Expand Up @@ -754,9 +754,7 @@ public void testTransactedSessionLeak() throws Exception {

Wait.assertEquals(0, () -> server.getSessions().size(), 1000, 100);

Acceptor stompAcceptor = server.getRemotingService().getAcceptors().get("stomp");
StompProtocolManager stompProtocolManager = (StompProtocolManager) stompAcceptor.getProtocolHandler().getProtocolMap().get("STOMP");
assertNotNull(stompProtocolManager);
StompProtocolManager stompProtocolManager = getStompProtocolManager();

assertEquals(0, stompProtocolManager.getTransactedSessions().size());
}
Expand Down Expand Up @@ -1530,6 +1528,71 @@ public void testSubscribeToTopicWithNoLocalSendWithStomp() throws Exception {
}
}

@Test
public void testSubscribeToTopicWithNoLocalAndNormal() throws Exception {
conn.connect(defUser, defPass);
String noLocalSubscriptionId = RandomUtil.randomString();
String normalSubscriptionId = RandomUtil.randomString();
subscribeTopic(conn, noLocalSubscriptionId, null, null, true, true);
subscribeTopic(conn, normalSubscriptionId, null, null, true, false);

StompProtocolManager stompProtocolManager = getStompProtocolManager();
int totalSubCount = 0;
int noLocalSubCount = 0;
for (StompSession session : stompProtocolManager.getSessions()) {
totalSubCount += session.getSubscriptionCount();
noLocalSubCount += session.getNoLocalSubscriptionCount();
}
assertEquals(1, noLocalSubCount);
assertEquals(2, totalSubCount);

{ // Send a message on the same connection. It should be received by the normal subscription and not by the noLocal one.
send(conn, getTopicPrefix() + getTopicName(), null, "Hello World");

ClientStompFrame frame = conn.receiveFrame(100);
assertNotNull(frame);
assertEquals(Stomp.Responses.MESSAGE, frame.getCommand());
assertEquals(normalSubscriptionId, frame.getHeader(Stomp.Headers.Message.SUBSCRIPTION));
assertNotNull(frame.getHeader("__AMQ_CID"));
frame = conn.receiveFrame(100);
assertNull(frame);
}

unsubscribe(conn, noLocalSubscriptionId, true);

totalSubCount = 0;
noLocalSubCount = 0;
for (StompSession session : stompProtocolManager.getSessions()) {
totalSubCount += session.getSubscriptionCount();
noLocalSubCount += session.getNoLocalSubscriptionCount();
}
assertEquals(0, noLocalSubCount);
assertEquals(1, totalSubCount);

{ // Send another message on the same connection. It should be received by the normal subscription.
send(conn, getTopicPrefix() + getTopicName(), null, "Hello World");

ClientStompFrame frame = conn.receiveFrame(100);
assertNotNull(frame);
assertEquals(Stomp.Responses.MESSAGE, frame.getCommand());
assertEquals(normalSubscriptionId, frame.getHeader(Stomp.Headers.Message.SUBSCRIPTION));
assertNull(frame.getHeader("__AMQ_CID"));
}

unsubscribe(conn, normalSubscriptionId, true);

totalSubCount = 0;
noLocalSubCount = 0;
for (StompSession session : stompProtocolManager.getSessions()) {
totalSubCount += session.getSubscriptionCount();
noLocalSubCount += session.getNoLocalSubscriptionCount();
}
assertEquals(0, noLocalSubCount);
assertEquals(0, totalSubCount);

conn.disconnect();
}

@Test
public void testSubscribeToTopicWithNoLocalAndSelector() throws Exception {
conn.connect(defUser, defPass);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.activemq.artemis.core.config.Configuration;
import org.apache.activemq.artemis.core.config.CoreAddressConfiguration;
import org.apache.activemq.artemis.core.protocol.stomp.Stomp;
import org.apache.activemq.artemis.core.protocol.stomp.StompProtocolManager;
import org.apache.activemq.artemis.core.protocol.stomp.StompProtocolManagerFactory;
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMAcceptorFactory;
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnectorFactory;
Expand All @@ -54,6 +55,7 @@
import org.apache.activemq.artemis.core.server.ActiveMQServers;
import org.apache.activemq.artemis.jms.client.ActiveMQConnectionFactory;
import org.apache.activemq.artemis.jms.client.ActiveMQJMSConnectionFactory;
import org.apache.activemq.artemis.spi.core.remoting.Acceptor;
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager;
import org.apache.activemq.artemis.tests.integration.stomp.util.AbstractStompClientConnection;
import org.apache.activemq.artemis.tests.integration.stomp.util.ClientStompFrame;
Expand Down Expand Up @@ -657,4 +659,11 @@ public static ClientStompFrame send(StompClientConnection conn, String destinati
public static URI createStompClientUri(String scheme, String hostname, int port) throws URISyntaxException {
return new URI(scheme + "://" + hostname + ":" + port);
}

protected StompProtocolManager getStompProtocolManager() {
Acceptor stompAcceptor = server.getRemotingService().getAcceptors().get("stomp");
StompProtocolManager stompProtocolManager = (StompProtocolManager) stompAcceptor.getProtocolHandler().getProtocolMap().get("STOMP");
assertNotNull(stompProtocolManager);
return stompProtocolManager;
}
}

0 comments on commit 09b8f67

Please sign in to comment.