Skip to content

Commit

Permalink
Merge pull request #44323 from mkouba/ws-next-session-context-optimiz…
Browse files Browse the repository at this point in the history
…ation

WebSockets Next: activate CDI session context only if needed
  • Loading branch information
mkouba authored Nov 5, 2024
2 parents cb47d8f + e14989f commit 2245430
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 49 deletions.
39 changes: 28 additions & 11 deletions docs/src/main/asciidoc/websockets-next-reference.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ However, developers can specify alternative scopes to fit their specific require
`@Singleton` and `@ApplicationScoped` endpoints are shared across all WebSocket connections.
Therefore, implementations should be either stateless or thread-safe.

==== Session context

If an endpoint is annotated with `@SessionScoped`, or depends directly or indirectly on a `@SessionScoped` bean, then each WebSocket connection is associated with its own _session context_.
The session context is active during endpoint callback invocation.
Subsequent invocations of <<callback-methods>> within the same connection utilize the same session context.
The session context remains active until the connection is closed (usually when the `@OnClose` method completes execution), at which point it is terminated.

TIP: It is also possible to set the `quarkus.websockets-next.server.activate-session-context` config property to `always`. In this case, the session context is always activated, no matter if a `@SessionScoped` bean participates in the dependency tree.

.`@SessionScoped` Endpoint
[source,java]
----
import jakarta.enterprise.context.SessionScoped;
Expand All @@ -172,20 +182,27 @@ public class MyWebSocket {
}
----
<1> This server endpoint is not shared and is scoped to the session.
<1> This server endpoint is not shared and is scoped to the session/connection.

==== Request context

If an endpoint is annotated with `@RequestScoped`, or with a security annotation (such as `@RolesAllowed`), or depends directly or indirectly on a `@RequestScoped` bean, or on a bean annotated with a security annotation, then each WebSocket endpoint callback method execution is associated with a new _request context_.
The request context is active during endpoint callback invocation.

TIP: It is also possible to set the `quarkus.websockets-next.server.activate-request-context` config property to `always`. In this case, the request context is always activated when an endpoint callback is invoked.

Each WebSocket connection is associated with its own _session_ context.
When the `@OnOpen` method is invoked, a session context corresponding to the WebSocket connection is created.
Subsequent calls to `@On[Text|Binary]Message` or `@OnClose` methods utilize this same session context.
The session context remains active until the `@OnClose` method completes execution, at which point it is terminated.
.`@RequestScoped` Endpoint
[source,java]
----
import jakarta.enterprise.context.RequestScoped;
In cases where a WebSocket endpoint does not declare an `@OnOpen` method, the session context is still created.
It remains active until the connection terminates, regardless of the presence of an `@OnClose` method.
@WebSocket(path = "/ws")
@RequestScoped <1>
public class MyWebSocket {
Endpoint callbacks may also have the request context activated for the duration of the method execution (until it produced its result).
By default, the request context is only activated if needed, i.e. if there is a request scoped bean , or a bean annotated with a security annotation (such as `@RolesAllowed`) in the dependency tree of the endpoint.
However, it is possible to set the `quarkus.websockets-next.server.activate-request-context` config property to `always`.
In this case, the request context is always activated when an endpoint callback is invoked.
}
----
<1> This server endpoint is instantiated for each callback method execution.

[[callback-methods]]
=== Callback methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import io.quarkus.arc.processor.InvokerInfo;
import io.quarkus.arc.processor.KotlinDotNames;
import io.quarkus.arc.processor.KotlinUtils;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.arc.processor.Types;
import io.quarkus.bootstrap.classloading.QuarkusClassLoader;
import io.quarkus.deployment.Capabilities;
Expand Down Expand Up @@ -234,7 +235,7 @@ ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuil

@BuildStep
CustomScopeBuildItem registerSessionScope() {
return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class.getName()));
return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class));
}

@BuildStep
Expand Down Expand Up @@ -466,18 +467,23 @@ public void registerRoutes(WebSocketServerRecorder recorder, List<WebSocketEndpo
.displayOnNotFoundPage("WebSocket Endpoint")
.handlerType(HandlerType.NORMAL)
.handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId,
activateRequestContext(config, endpoint.endpointId, endpoints, validationPhase.getBeanResolver()),
activateContext(config.activateRequestContext(), BuiltinScope.REQUEST.getInfo(),
endpoint.endpointId, endpoints, validationPhase.getBeanResolver()),
activateContext(config.activateSessionContext(),
new ScopeInfo(DotName.createSimple(SessionScoped.class), true), endpoint.endpointId,
endpoints, validationPhase.getBeanResolver()),
endpoint.path));
routes.produce(builder.build());
}
}

private boolean activateRequestContext(WebSocketsServerBuildConfig config, String endpointId,
private boolean activateContext(WebSocketsServerBuildConfig.ContextActivation activation, ScopeInfo scope,
String endpointId,
List<WebSocketEndpointBuildItem> endpoints, BeanResolver beanResolver) {
return switch (config.activateRequestContext()) {
return switch (activation) {
case ALWAYS -> true;
case AUTO -> needsRequestContext(findEndpoint(endpointId, endpoints).bean, new HashSet<>(), beanResolver);
default -> throw new IllegalArgumentException("Unexpected value: " + config.activateRequestContext());
case AUTO -> needsContext(findEndpoint(endpointId, endpoints).bean, scope, new HashSet<>(), beanResolver);
default -> throw new IllegalArgumentException("Unexpected value: " + activation);
};
}

Expand All @@ -490,21 +496,23 @@ private WebSocketEndpointBuildItem findEndpoint(String endpointId, List<WebSocke
throw new IllegalArgumentException("Endpoint not found: " + endpointId);
}

private boolean needsRequestContext(BeanInfo bean, Set<String> processedBeans, BeanResolver beanResolver) {
private boolean needsContext(BeanInfo bean, ScopeInfo scope, Set<String> processedBeans, BeanResolver beanResolver) {
if (processedBeans.add(bean.getIdentifier())) {
if (BuiltinScope.REQUEST.is(bean.getScope())
|| (bean.isClassBean()
&& bean.hasAroundInvokeInterceptors()
&& SecurityTransformerUtils.hasSecurityAnnotation(bean.getTarget().get().asClass()))) {
// Bean is:
// 1. Request scoped, or
// 2. Is class-based, has an aroundInvoke interceptor associated and is annotated with a security annotation

if (scope.equals(bean.getScope())) {
// Bean has the given scope
return true;
} else if (BuiltinScope.REQUEST.is(scope)
&& bean.isClassBean()
&& bean.hasAroundInvokeInterceptors()
&& SecurityTransformerUtils.hasSecurityAnnotation(bean.getTarget().get().asClass())) {
// The given scope is RequestScoped, the bean is class-based, has an aroundInvoke interceptor associated and is annotated with a security annotation
return true;
}
for (InjectionPointInfo injectionPoint : bean.getAllInjectionPoints()) {
BeanInfo dependency = injectionPoint.getResolvedBean();
if (dependency != null) {
if (needsRequestContext(dependency, processedBeans, beanResolver)) {
if (needsContext(dependency, scope, processedBeans, beanResolver)) {
return true;
}
} else {
Expand All @@ -525,7 +533,7 @@ private boolean needsRequestContext(BeanInfo bean, Set<String> processedBeans, B
if (requiredType != null) {
// For programmatic lookup and @All List<> we need to resolve the beans manually
for (BeanInfo lookupDependency : beanResolver.resolveBeans(requiredType, qualifiers)) {
if (needsRequestContext(lookupDependency, processedBeans, beanResolver)) {
if (needsContext(lookupDependency, scope, processedBeans, beanResolver)) {
return true;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.quarkus.websockets.next.test.requestcontext;

import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.RequestScoped;
import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;

public class RequestScopedEndpointTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Echo.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("echo")
URI echo;

@Test
void testRequestScopedEndpoint() throws InterruptedException {
try (WSClient client = WSClient.create(vertx).connect(echo)) {
client.send("foo");
client.send("bar");
client.send("baz");
client.waitForMessages(3);
assertTrue(Echo.DESTROYED_LATCH.await(5, TimeUnit.SECONDS),
"Latch count: " + Echo.DESTROYED_LATCH.getCount());
}
}

@RequestScoped
@WebSocket(path = "/echo")
public static class Echo {

static final CountDownLatch DESTROYED_LATCH = new CountDownLatch(3);

@OnTextMessage
String echo(String message) {
if (!Arc.container().requestContext().isActive()) {
throw new IllegalStateException();
}
return message;
}

@PreDestroy
void destroy() {
DESTROYED_LATCH.countDown();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class SubprotocolNotAvailableTest {
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
});
}).overrideConfigKey("quarkus.websockets-next.server.activate-session-context", "always");

@Inject
Vertx vertx;
Expand All @@ -44,7 +44,12 @@ public class SubprotocolNotAvailableTest {
@Test
void testConnectionRejected() {
CompletionException e = assertThrows(CompletionException.class,
() -> new WSClient(vertx).connect(new WebSocketConnectOptions().addSubProtocol("oak"), endUri));
() -> {
try (WSClient connect = new WSClient(vertx).connect(new WebSocketConnectOptions().addSubProtocol("oak"),
endUri)) {
// handshake should fail
}
});
Throwable cause = e.getCause();
assertTrue(cause instanceof WebSocketClientHandshakeException);
assertFalse(Endpoint.OPEN_CALLED.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,28 @@
public interface WebSocketsServerBuildConfig {

/**
* Specifies whether to activate the CDI request context when an endpoint callback is invoked. By default, the request
* context is only activated if needed.
* Specifies the activation strategy for the CDI request context during endpoint callback invocation. By default, the
* request context is only activated if needed, i.e. if there is a bean with the given scope, or a bean annotated
* with a security annotation (such as {@code @RolesAllowed}), in the dependency tree of the endpoint.
*/
@WithDefault("auto")
RequestContextActivation activateRequestContext();
ContextActivation activateRequestContext();

enum RequestContextActivation {
/**
* Specifies the activation strategy for the CDI session context during endpoint callback invocation. By default, the
* session context is only activated if needed, i.e. if there is a bean with the given scope in the dependency tree of the
* endpoint.
*/
@WithDefault("auto")
ContextActivation activateSessionContext();

enum ContextActivation {
/**
* The request context is only activated if needed, i.e. if there is a request scoped bean , or a bean annotated
* with a security annotation (such as {@code @RolesAllowed}) in the dependency tree of the endpoint.
* The context is only activated if needed.
*/
AUTO,
/**
* The request context is always activated.
* The context is always activated.
*/
ALWAYS
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkus.websockets.next.runtime;

import java.util.Objects;

import org.jboss.logging.Logger;

import io.quarkus.arc.InjectableContext.ContextState;
Expand All @@ -9,6 +11,9 @@
import io.smallrye.common.vertx.VertxContext;
import io.vertx.core.Context;

/**
* Per-endpoint CDI context support.
*/
public class ContextSupport {

private static final Logger LOG = Logger.getLogger(ContextSupport.class);
Expand All @@ -24,9 +29,9 @@ public class ContextSupport {
WebSocketSessionContext sessionContext,
ManagedContext requestContext) {
this.connection = connection;
this.sessionContextState = sessionContextState;
this.sessionContext = sessionContext;
this.requestContext = requestContext;
this.sessionContextState = sessionContext != null ? Objects.requireNonNull(sessionContextState) : null;
}

void start() {
Expand All @@ -42,8 +47,10 @@ void start(ContextState requestContextState) {
}

void startSession() {
// Activate the captured session context
sessionContext.activate(sessionContextState);
if (sessionContext != null) {
// Activate the captured session context
sessionContext.activate(sessionContextState);
}
}

void end(boolean terminateSession) {
Expand All @@ -63,13 +70,15 @@ void end(boolean terminateRequest, boolean terminateSession) {
if (terminateSession) {
// OnClose - terminate the session context
endSession();
} else {
} else if (sessionContext != null) {
sessionContext.deactivate();
}
}

void endSession() {
sessionContext.terminate();
if (sessionContext != null) {
sessionContext.terminate();
}
}

static Context createNewDuplicatedContext(Context context, WebSocketConnectionBase connection) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,21 @@ class Endpoints {
static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSocketConnectionBase connection,
WebSocketBase ws, String generatedEndpointClass, Optional<Duration> autoPingInterval,
SecuritySupport securitySupport, UnhandledFailureStrategy unhandledFailureStrategy, TrafficLogger trafficLogger,
Runnable onClose, boolean activateRequestContext, TelemetrySupport telemetrySupport) {
Runnable onClose, boolean activateRequestContext, boolean activateSessionContext,
TelemetrySupport telemetrySupport) {

Context context = vertx.getOrCreateContext();

// Initialize and capture the session context state that will be activated
// during message processing
WebSocketSessionContext sessionContext = sessionContext(container);
SessionContextState sessionContextState = sessionContext.initializeContextState();
WebSocketSessionContext sessionContext = null;
SessionContextState sessionContextState = null;
if (activateSessionContext) {
sessionContext = sessionContext(container);
sessionContextState = sessionContext.initializeContextState();
}
ContextSupport contextSupport = new ContextSupport(connection, sessionContextState,
sessionContext(container),
activateRequestContext ? container.requestContext() : null);
sessionContext, activateRequestContext ? container.requestContext() : null);

// Create an endpoint that delegates callbacks to the endpoint bean
WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, contextSupport,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public void handle(AsyncResult<WebSocket> r) {
() -> {
connectionManager.remove(clientEndpoint.generatedEndpointClass, connection);
client.get().close();
}, true, telemetrySupport);
}, true, true, telemetrySupport);

return connection;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public Object get() {
}

public Handler<RoutingContext> createEndpointHandler(String generatedEndpointClass, String endpointId,
boolean activateRequestContext, String endpointPath) {
boolean activateRequestContext, boolean activateSessionContext, String endpointPath) {
ArcContainer container = Arc.container();
ConnectionManager connectionManager = container.instance(ConnectionManager.class).get();
Codecs codecs = container.instance(Codecs.class).get();
Expand Down Expand Up @@ -125,7 +125,7 @@ public void handle(Throwable throwable) {
Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass,
config.autoPingInterval(), securitySupport, config.unhandledFailureStrategy(), trafficLogger,
() -> connectionManager.remove(generatedEndpointClass, connection), activateRequestContext,
telemetrySupport);
activateSessionContext, telemetrySupport);
});
}

Expand Down

0 comments on commit 2245430

Please sign in to comment.