Skip to content

Commit

Permalink
Fix ThreadSPropagationChInterceptor for stacking (#8735)
Browse files Browse the repository at this point in the history
* Fix ThreadSPropagationChInterceptor for stacking

Related SO thread: https://stackoverflow.com/questions/77058188/multiple-threadstatepropagationchannelinterceptors-not-possible

The current `ThreadStatePropagationChannelInterceptor` logic is to wrap one
message to another (`MessageWithThreadState`), essentially stacking contexts.
The `postReceive()` logic is to unwrap a `MessageWithThreadState`,
therefore we deal with the latest pushed context which leads to the `ClassCastException`

* Rework `ThreadStatePropagationChannelInterceptor` logic to reuse existing `MessageWithThreadState`
and add the current context to its `stateQueue`.
Therefore, the `postReceive()` will `poll()` the oldest context which is, essentially,
the one populated by this interceptor before, according to the interceptors order
* Fix `AbstractMessageChannel.setInterceptors()` to not modify provided list of interceptors
* The new `ThreadStatePropagationChannelInterceptorTests` demonstrates the problem
described in that mentioned SO question and verifies that context are propagated
in the order they have been populated

**Cherry-pick to `6.1.x` & `6.0.x`**

* * Fix `ThreadStatePropagationChannelInterceptor` for publish-subscribe scenario.
Essentially, copy the state queue to a new decorated message
* Fix `BroadcastingDispatcher` to always decorate message, even if not `applySequence`

* * Fix unused import in the `BroadcastingDispatcher`

* * Fix unused import in the `ThreadStatePropagationChannelInterceptor`
  • Loading branch information
artembilan authored Sep 18, 2023
1 parent 0f13044 commit eafedaa
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.integration.channel;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -164,8 +165,9 @@ public void setDatatypes(Class<?>... datatypes) {
*/
@Override
public void setInterceptors(List<ChannelInterceptor> interceptors) {
interceptors.sort(this.orderComparator);
this.interceptors.set(interceptors);
List<ChannelInterceptor> interceptorsToUse = new ArrayList<>(interceptors);
interceptorsToUse.sort(this.orderComparator);
this.interceptors.set(interceptorsToUse);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package org.springframework.integration.channel.interceptor;

import java.util.LinkedList;
import java.util.Queue;

import io.micrometer.common.lang.Nullable;

import org.springframework.integration.support.MessageDecorator;
Expand Down Expand Up @@ -58,20 +61,27 @@ public abstract class ThreadStatePropagationChannelInterceptor<S> implements Exe
public final Message<?> preSend(Message<?> message, MessageChannel channel) {
S threadContext = obtainPropagatingContext(message, channel);
if (threadContext != null) {
return new MessageWithThreadState<>(message, threadContext);
}
else {
return message;
if (message instanceof MessageWithThreadState messageWithThreadState) {
messageWithThreadState.stateQueue.add(threadContext);
}
else {
return new MessageWithThreadState(message, threadContext);
}
}

return message;
}

@Override
@SuppressWarnings("unchecked")
public final Message<?> postReceive(Message<?> message, MessageChannel channel) {
if (message instanceof MessageWithThreadState) {
MessageWithThreadState<S> messageWithThreadState = (MessageWithThreadState<S>) message;
Message<?> messageToHandle = messageWithThreadState.message;
populatePropagatedContext(messageWithThreadState.state, messageToHandle, channel);
if (message instanceof MessageWithThreadState messageWithThreadState) {
Object threadContext = messageWithThreadState.stateQueue.poll();
Message<?> messageToHandle = messageWithThreadState;
if (messageWithThreadState.stateQueue.isEmpty()) {
messageToHandle = messageWithThreadState.message;
}
populatePropagatedContext((S) threadContext, messageToHandle, channel);
return messageToHandle;
}
return message;
Expand All @@ -88,16 +98,21 @@ public final Message<?> beforeHandle(Message<?> message, MessageChannel channel,
protected abstract void populatePropagatedContext(@Nullable S state, Message<?> message, MessageChannel channel);


private static final class MessageWithThreadState<S> implements Message<Object>, MessageDecorator {
private static final class MessageWithThreadState implements Message<Object>, MessageDecorator {

private final Message<Object> message;

private final S state;
private final Queue<Object> stateQueue;

MessageWithThreadState(Message<?> message, Object state) {
this(message, new LinkedList<>());
this.stateQueue.add(state);
}

@SuppressWarnings("unchecked")
MessageWithThreadState(Message<?> message, S state) {
private MessageWithThreadState(Message<?> message, Queue<Object> stateQueue) {
this.message = (Message<Object>) message;
this.state = state;
this.stateQueue = new LinkedList<>(stateQueue);
}

@Override
Expand All @@ -112,14 +127,14 @@ public MessageHeaders getHeaders() {

@Override
public Message<?> decorateMessage(Message<?> message) {
return new MessageWithThreadState<>(message, this.state);
return new MessageWithThreadState(message, this.stateQueue);
}

@Override
public String toString() {
return "MessageWithThreadState{" +
"message=" + this.message +
", state=" + this.state +
", state=" + this.stateQueue +
'}';
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,6 @@
package org.springframework.integration.dispatcher;

import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.Executor;

import org.springframework.beans.BeansException;
Expand Down Expand Up @@ -57,13 +56,13 @@ public class BroadcastingDispatcher extends AbstractDispatcher implements BeanFa

private final boolean requireSubscribers;

private volatile boolean ignoreFailures;
private final Executor executor;

private volatile boolean applySequence;
private boolean ignoreFailures;

private final Executor executor;
private boolean applySequence;

private volatile int minSubscribers;
private int minSubscribers;

private MessageHandlingTaskDecorator messageHandlingTaskDecorator = task -> task;

Expand Down Expand Up @@ -149,24 +148,20 @@ public boolean dispatch(Message<?> message) {
int dispatched = 0;
int sequenceNumber = 1;
Collection<MessageHandler> handlers = this.getHandlers();
if (this.requireSubscribers && handlers.size() == 0) {
if (this.requireSubscribers && handlers.isEmpty()) {
throw new MessageDispatchingException(message, "Dispatcher has no subscribers");
}
int sequenceSize = handlers.size();
Message<?> messageToSend = message;
UUID sequenceId = null;
if (this.applySequence) {
sequenceId = message.getHeaders().getId();
}
for (MessageHandler handler : handlers) {
if (this.applySequence) {
messageToSend = getMessageBuilderFactory()
.fromMessage(message)
.pushSequenceDetails(sequenceId, sequenceNumber++, sequenceSize)
.pushSequenceDetails(message.getHeaders().getId(), sequenceNumber++, sequenceSize)
.build();
if (message instanceof MessageDecorator) {
messageToSend = ((MessageDecorator) message).decorateMessage(messageToSend);
}
}
if (message instanceof MessageDecorator messageDecorator) {
messageToSend = messageDecorator.decorateMessage(messageToSend);
}

if (this.executor != null) {
Expand All @@ -175,7 +170,7 @@ public boolean dispatch(Message<?> message) {
dispatched++;
}
else {
if (this.invokeHandler(handler, messageToSend)) {
if (invokeHandler(handler, messageToSend)) {
dispatched++;
}
}
Expand Down Expand Up @@ -222,15 +217,15 @@ private boolean invokeHandler(MessageHandler handler, Message<?> message) {
handler.handleMessage(message);
return true;
}
catch (RuntimeException e) {
catch (RuntimeException ex) {
if (!this.ignoreFailures) {
if (e instanceof MessagingException && ((MessagingException) e).getFailedMessage() == null) { // NOSONAR
throw new MessagingException(message, "Failed to handle Message", e);
if (ex instanceof MessagingException exception && exception.getFailedMessage() == null) { // NOSONAR
throw new MessagingException(message, "Failed to handle Message", ex);
}
throw e;
throw ex;
}
else if (this.logger.isWarnEnabled()) {
logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", e);
else {
logger.warn("Suppressing Exception since 'ignoreFailures' is set to TRUE.", ex);
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.integration.channel.interceptor;

import java.util.ArrayList;
import java.util.List;

import org.junit.jupiter.api.Test;

import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.integration.channel.ExecutorChannel;
import org.springframework.integration.util.ErrorHandlingTaskExecutor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.util.ReflectionUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

/**
* @author Artem Bilan
*
* @since 6.2
*/
public class ThreadStatePropagationChannelInterceptorTests {

@Test
void ThreadStatePropagationChannelInterceptorsCanBeStacked() {
TestContext1 ctx1 = new TestContext1();
TestContext2 ctx2 = new TestContext2();

List<Object> propagatedContexts = new ArrayList<>();

var interceptor1 = new ThreadStatePropagationChannelInterceptor<TestContext1>() {
@Override
protected TestContext1 obtainPropagatingContext(Message<?> message, MessageChannel channel) {
return ctx1;
}

@Override
protected void populatePropagatedContext(TestContext1 state, Message<?> message, MessageChannel channel) {
propagatedContexts.add(state);
}

};

var interceptor2 = new ThreadStatePropagationChannelInterceptor<TestContext2>() {
@Override
protected TestContext2 obtainPropagatingContext(Message<?> message, MessageChannel channel) {
return ctx2;
}

@Override
protected void populatePropagatedContext(TestContext2 state, Message<?> message, MessageChannel channel) {
propagatedContexts.add(state);
}

};

ExecutorChannel testChannel = new ExecutorChannel(
new ErrorHandlingTaskExecutor(new SyncTaskExecutor(), ReflectionUtils::rethrowRuntimeException));
testChannel.setInterceptors(List.of(interceptor1, interceptor2));
testChannel.setBeanFactory(mock());
testChannel.afterPropertiesSet();
testChannel.subscribe(m -> {
});

testChannel.send(new GenericMessage<>("test data"));

assertThat(propagatedContexts.get(0)).isEqualTo(ctx1);
assertThat(propagatedContexts.get(1)).isEqualTo(ctx2);
}

private record TestContext1() {
}

private record TestContext2() {
}

}

0 comments on commit eafedaa

Please sign in to comment.