diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java index 1517392ed36..2ddc0c0d241 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/AbstractMessageChannel.java @@ -17,6 +17,7 @@ package org.springframework.integration.channel; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; @@ -164,8 +165,9 @@ public void setDatatypes(Class... datatypes) { */ @Override public void setInterceptors(List interceptors) { - interceptors.sort(this.orderComparator); - this.interceptors.set(interceptors); + List interceptorsToUse = new ArrayList<>(interceptors); + interceptorsToUse.sort(this.orderComparator); + this.interceptors.set(interceptorsToUse); } /** diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java index e8ca1124559..19408c21249 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptor.java @@ -16,6 +16,9 @@ package org.springframework.integration.channel.interceptor; +import java.util.LinkedList; +import java.util.Queue; + import io.micrometer.common.lang.Nullable; import org.springframework.integration.support.MessageDecorator; @@ -58,20 +61,27 @@ public abstract class ThreadStatePropagationChannelInterceptor implements Exe public final Message preSend(Message message, MessageChannel channel) { S threadContext = obtainPropagatingContext(message, channel); if (threadContext != null) { - return new MessageWithThreadState<>(message, threadContext); - } - else { - return message; + if (message instanceof MessageWithThreadState messageWithThreadState) { + messageWithThreadState.stateQueue.add(threadContext); + } + else { + return new MessageWithThreadState(message, threadContext); + } } + + return message; } @Override @SuppressWarnings("unchecked") public final Message postReceive(Message message, MessageChannel channel) { - if (message instanceof MessageWithThreadState) { - MessageWithThreadState messageWithThreadState = (MessageWithThreadState) message; - Message messageToHandle = messageWithThreadState.message; - populatePropagatedContext(messageWithThreadState.state, messageToHandle, channel); + if (message instanceof MessageWithThreadState messageWithThreadState) { + Object threadContext = messageWithThreadState.stateQueue.poll(); + Message messageToHandle = messageWithThreadState; + if (messageWithThreadState.stateQueue.isEmpty()) { + messageToHandle = messageWithThreadState.message; + } + populatePropagatedContext((S) threadContext, messageToHandle, channel); return messageToHandle; } return message; @@ -88,16 +98,21 @@ public final Message beforeHandle(Message message, MessageChannel channel, protected abstract void populatePropagatedContext(@Nullable S state, Message message, MessageChannel channel); - private static final class MessageWithThreadState implements Message, MessageDecorator { + private static final class MessageWithThreadState implements Message, MessageDecorator { private final Message message; - private final S state; + private final Queue stateQueue; + + MessageWithThreadState(Message message, Object state) { + this(message, new LinkedList<>()); + this.stateQueue.add(state); + } @SuppressWarnings("unchecked") - MessageWithThreadState(Message message, S state) { + private MessageWithThreadState(Message message, Queue stateQueue) { this.message = (Message) message; - this.state = state; + this.stateQueue = new LinkedList<>(stateQueue); } @Override @@ -112,14 +127,14 @@ public MessageHeaders getHeaders() { @Override public Message decorateMessage(Message message) { - return new MessageWithThreadState<>(message, this.state); + return new MessageWithThreadState(message, this.stateQueue); } @Override public String toString() { return "MessageWithThreadState{" + "message=" + this.message + - ", state=" + this.state + + ", state=" + this.stateQueue + '}'; } diff --git a/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java b/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java index ec843b2444c..43dec106abb 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/dispatcher/BroadcastingDispatcher.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ package org.springframework.integration.dispatcher; import java.util.Collection; -import java.util.UUID; import java.util.concurrent.Executor; import org.springframework.beans.BeansException; @@ -57,13 +56,13 @@ public class BroadcastingDispatcher extends AbstractDispatcher implements BeanFa private final boolean requireSubscribers; - private volatile boolean ignoreFailures; + private final Executor executor; - private volatile boolean applySequence; + private boolean ignoreFailures; - private final Executor executor; + private boolean applySequence; - private volatile int minSubscribers; + private int minSubscribers; private MessageHandlingTaskDecorator messageHandlingTaskDecorator = task -> task; @@ -149,24 +148,20 @@ public boolean dispatch(Message message) { int dispatched = 0; int sequenceNumber = 1; Collection handlers = this.getHandlers(); - if (this.requireSubscribers && handlers.size() == 0) { + if (this.requireSubscribers && handlers.isEmpty()) { throw new MessageDispatchingException(message, "Dispatcher has no subscribers"); } int sequenceSize = handlers.size(); Message messageToSend = message; - UUID sequenceId = null; - if (this.applySequence) { - sequenceId = message.getHeaders().getId(); - } for (MessageHandler handler : handlers) { if (this.applySequence) { messageToSend = getMessageBuilderFactory() .fromMessage(message) - .pushSequenceDetails(sequenceId, sequenceNumber++, sequenceSize) + .pushSequenceDetails(message.getHeaders().getId(), sequenceNumber++, sequenceSize) .build(); - if (message instanceof MessageDecorator) { - messageToSend = ((MessageDecorator) message).decorateMessage(messageToSend); - } + } + if (message instanceof MessageDecorator messageDecorator) { + messageToSend = messageDecorator.decorateMessage(messageToSend); } if (this.executor != null) { @@ -175,7 +170,7 @@ public boolean dispatch(Message message) { dispatched++; } else { - if (this.invokeHandler(handler, messageToSend)) { + if (invokeHandler(handler, messageToSend)) { dispatched++; } } @@ -222,15 +217,15 @@ private boolean invokeHandler(MessageHandler handler, Message message) { handler.handleMessage(message); return true; } - catch (RuntimeException e) { + catch (RuntimeException ex) { if (!this.ignoreFailures) { - if (e instanceof MessagingException && ((MessagingException) e).getFailedMessage() == null) { // NOSONAR - throw new MessagingException(message, "Failed to handle Message", e); + if (ex instanceof MessagingException exception && exception.getFailedMessage() == null) { // NOSONAR + throw new MessagingException(message, "Failed to handle Message", ex); } - throw e; + throw ex; } - else if (this.logger.isWarnEnabled()) { - logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", e); + else { + logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", ex); } return false; } diff --git a/spring-integration-core/src/test/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptorTests.java b/spring-integration-core/src/test/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptorTests.java new file mode 100644 index 00000000000..5d48a0271a1 --- /dev/null +++ b/spring-integration-core/src/test/java/org/springframework/integration/channel/interceptor/ThreadStatePropagationChannelInterceptorTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.integration.channel.interceptor; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.core.task.SyncTaskExecutor; +import org.springframework.integration.channel.ExecutorChannel; +import org.springframework.integration.util.ErrorHandlingTaskExecutor; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.support.GenericMessage; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * @author Artem Bilan + * + * @since 6.2 + */ +public class ThreadStatePropagationChannelInterceptorTests { + + @Test + void ThreadStatePropagationChannelInterceptorsCanBeStacked() { + TestContext1 ctx1 = new TestContext1(); + TestContext2 ctx2 = new TestContext2(); + + List propagatedContexts = new ArrayList<>(); + + var interceptor1 = new ThreadStatePropagationChannelInterceptor() { + @Override + protected TestContext1 obtainPropagatingContext(Message message, MessageChannel channel) { + return ctx1; + } + + @Override + protected void populatePropagatedContext(TestContext1 state, Message message, MessageChannel channel) { + propagatedContexts.add(state); + } + + }; + + var interceptor2 = new ThreadStatePropagationChannelInterceptor() { + @Override + protected TestContext2 obtainPropagatingContext(Message message, MessageChannel channel) { + return ctx2; + } + + @Override + protected void populatePropagatedContext(TestContext2 state, Message message, MessageChannel channel) { + propagatedContexts.add(state); + } + + }; + + ExecutorChannel testChannel = new ExecutorChannel( + new ErrorHandlingTaskExecutor(new SyncTaskExecutor(), ReflectionUtils::rethrowRuntimeException)); + testChannel.setInterceptors(List.of(interceptor1, interceptor2)); + testChannel.setBeanFactory(mock()); + testChannel.afterPropertiesSet(); + testChannel.subscribe(m -> { + }); + + testChannel.send(new GenericMessage<>("test data")); + + assertThat(propagatedContexts.get(0)).isEqualTo(ctx1); + assertThat(propagatedContexts.get(1)).isEqualTo(ctx2); + } + + private record TestContext1() { + } + + private record TestContext2() { + } + +}