From e14989f0ef383ef4ff27379448c601bf42c70b00 Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Tue, 5 Nov 2024 14:03:34 +0100 Subject: [PATCH] WebSockets Next: activate CDI session context only if needed - follo-up to https://github.com/quarkusio/quarkus/pull/43915 - related to #39148 --- .../asciidoc/websockets-next-reference.adoc | 39 ++++++++--- .../next/deployment/WebSocketProcessor.java | 40 ++++++----- .../RequestScopedEndpointTest.java | 70 +++++++++++++++++++ .../SubprotocolNotAvailableTest.java | 9 ++- .../next/WebSocketsServerBuildConfig.java | 22 ++++-- .../next/runtime/ContextSupport.java | 19 +++-- .../websockets/next/runtime/Endpoints.java | 14 ++-- .../next/runtime/WebSocketConnectorImpl.java | 2 +- .../next/runtime/WebSocketServerRecorder.java | 4 +- 9 files changed, 170 insertions(+), 49 deletions(-) create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestScopedEndpointTest.java diff --git a/docs/src/main/asciidoc/websockets-next-reference.adoc b/docs/src/main/asciidoc/websockets-next-reference.adoc index 7935bb9974358..e90a96fe2284c 100644 --- a/docs/src/main/asciidoc/websockets-next-reference.adoc +++ b/docs/src/main/asciidoc/websockets-next-reference.adoc @@ -162,6 +162,16 @@ However, developers can specify alternative scopes to fit their specific require `@Singleton` and `@ApplicationScoped` endpoints are shared across all WebSocket connections. Therefore, implementations should be either stateless or thread-safe. +==== Session context + +If an endpoint is annotated with `@SessionScoped`, or depends directly or indirectly on a `@SessionScoped` bean, then each WebSocket connection is associated with its own _session context_. +The session context is active during endpoint callback invocation. +Subsequent invocations of <> within the same connection utilize the same session context. +The session context remains active until the connection is closed (usually when the `@OnClose` method completes execution), at which point it is terminated. + +TIP: It is also possible to set the `quarkus.websockets-next.server.activate-session-context` config property to `always`. In this case, the session context is always activated, no matter if a `@SessionScoped` bean participates in the dependency tree. + +.`@SessionScoped` Endpoint [source,java] ---- import jakarta.enterprise.context.SessionScoped; @@ -172,20 +182,27 @@ public class MyWebSocket { } ---- -<1> This server endpoint is not shared and is scoped to the session. +<1> This server endpoint is not shared and is scoped to the session/connection. + +==== Request context + +If an endpoint is annotated with `@RequestScoped`, or with a security annotation (such as `@RolesAllowed`), or depends directly or indirectly on a `@RequestScoped` bean, or on a bean annotated with a security annotation, then each WebSocket endpoint callback method execution is associated with a new _request context_. +The request context is active during endpoint callback invocation. + +TIP: It is also possible to set the `quarkus.websockets-next.server.activate-request-context` config property to `always`. In this case, the request context is always activated when an endpoint callback is invoked. -Each WebSocket connection is associated with its own _session_ context. -When the `@OnOpen` method is invoked, a session context corresponding to the WebSocket connection is created. -Subsequent calls to `@On[Text|Binary]Message` or `@OnClose` methods utilize this same session context. -The session context remains active until the `@OnClose` method completes execution, at which point it is terminated. +.`@RequestScoped` Endpoint +[source,java] +---- +import jakarta.enterprise.context.RequestScoped; -In cases where a WebSocket endpoint does not declare an `@OnOpen` method, the session context is still created. -It remains active until the connection terminates, regardless of the presence of an `@OnClose` method. +@WebSocket(path = "/ws") +@RequestScoped <1> +public class MyWebSocket { -Endpoint callbacks may also have the request context activated for the duration of the method execution (until it produced its result). -By default, the request context is only activated if needed, i.e. if there is a request scoped bean , or a bean annotated with a security annotation (such as `@RolesAllowed`) in the dependency tree of the endpoint. -However, it is possible to set the `quarkus.websockets-next.server.activate-request-context` config property to `always`. -In this case, the request context is always activated when an endpoint callback is invoked. +} +---- +<1> This server endpoint is instantiated for each callback method execution. [[callback-methods]] === Callback methods diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index dd07dd2ca390e..f32cb7327b77a 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -65,6 +65,7 @@ import io.quarkus.arc.processor.InvokerInfo; import io.quarkus.arc.processor.KotlinDotNames; import io.quarkus.arc.processor.KotlinUtils; +import io.quarkus.arc.processor.ScopeInfo; import io.quarkus.arc.processor.Types; import io.quarkus.bootstrap.classloading.QuarkusClassLoader; import io.quarkus.deployment.Capabilities; @@ -234,7 +235,7 @@ ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuil @BuildStep CustomScopeBuildItem registerSessionScope() { - return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class.getName())); + return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class)); } @BuildStep @@ -466,18 +467,23 @@ public void registerRoutes(WebSocketServerRecorder recorder, List endpoints, BeanResolver beanResolver) { - return switch (config.activateRequestContext()) { + return switch (activation) { case ALWAYS -> true; - case AUTO -> needsRequestContext(findEndpoint(endpointId, endpoints).bean, new HashSet<>(), beanResolver); - default -> throw new IllegalArgumentException("Unexpected value: " + config.activateRequestContext()); + case AUTO -> needsContext(findEndpoint(endpointId, endpoints).bean, scope, new HashSet<>(), beanResolver); + default -> throw new IllegalArgumentException("Unexpected value: " + activation); }; } @@ -490,21 +496,23 @@ private WebSocketEndpointBuildItem findEndpoint(String endpointId, List processedBeans, BeanResolver beanResolver) { + private boolean needsContext(BeanInfo bean, ScopeInfo scope, Set processedBeans, BeanResolver beanResolver) { if (processedBeans.add(bean.getIdentifier())) { - if (BuiltinScope.REQUEST.is(bean.getScope()) - || (bean.isClassBean() - && bean.hasAroundInvokeInterceptors() - && SecurityTransformerUtils.hasSecurityAnnotation(bean.getTarget().get().asClass()))) { - // Bean is: - // 1. Request scoped, or - // 2. Is class-based, has an aroundInvoke interceptor associated and is annotated with a security annotation + + if (scope.equals(bean.getScope())) { + // Bean has the given scope + return true; + } else if (BuiltinScope.REQUEST.is(scope) + && bean.isClassBean() + && bean.hasAroundInvokeInterceptors() + && SecurityTransformerUtils.hasSecurityAnnotation(bean.getTarget().get().asClass())) { + // The given scope is RequestScoped, the bean is class-based, has an aroundInvoke interceptor associated and is annotated with a security annotation return true; } for (InjectionPointInfo injectionPoint : bean.getAllInjectionPoints()) { BeanInfo dependency = injectionPoint.getResolvedBean(); if (dependency != null) { - if (needsRequestContext(dependency, processedBeans, beanResolver)) { + if (needsContext(dependency, scope, processedBeans, beanResolver)) { return true; } } else { @@ -525,7 +533,7 @@ private boolean needsRequestContext(BeanInfo bean, Set processedBeans, B if (requiredType != null) { // For programmatic lookup and @All List<> we need to resolve the beans manually for (BeanInfo lookupDependency : beanResolver.resolveBeans(requiredType, qualifiers)) { - if (needsRequestContext(lookupDependency, processedBeans, beanResolver)) { + if (needsContext(lookupDependency, scope, processedBeans, beanResolver)) { return true; } } diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestScopedEndpointTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestScopedEndpointTest.java new file mode 100644 index 0000000000000..b8d967cdc6c79 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/requestcontext/RequestScopedEndpointTest.java @@ -0,0 +1,70 @@ +package io.quarkus.websockets.next.test.requestcontext; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.RequestScoped; +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class RequestScopedEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Echo.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("echo") + URI echo; + + @Test + void testRequestScopedEndpoint() throws InterruptedException { + try (WSClient client = WSClient.create(vertx).connect(echo)) { + client.send("foo"); + client.send("bar"); + client.send("baz"); + client.waitForMessages(3); + assertTrue(Echo.DESTROYED_LATCH.await(5, TimeUnit.SECONDS), + "Latch count: " + Echo.DESTROYED_LATCH.getCount()); + } + } + + @RequestScoped + @WebSocket(path = "/echo") + public static class Echo { + + static final CountDownLatch DESTROYED_LATCH = new CountDownLatch(3); + + @OnTextMessage + String echo(String message) { + if (!Arc.container().requestContext().isActive()) { + throw new IllegalStateException(); + } + return message; + } + + @PreDestroy + void destroy() { + DESTROYED_LATCH.countDown(); + } + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolNotAvailableTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolNotAvailableTest.java index 9a79b8d12fda0..c862c5f068c20 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolNotAvailableTest.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolNotAvailableTest.java @@ -33,7 +33,7 @@ public class SubprotocolNotAvailableTest { public static final QuarkusUnitTest test = new QuarkusUnitTest() .withApplicationRoot(root -> { root.addClasses(Endpoint.class, WSClient.class); - }); + }).overrideConfigKey("quarkus.websockets-next.server.activate-session-context", "always"); @Inject Vertx vertx; @@ -44,7 +44,12 @@ public class SubprotocolNotAvailableTest { @Test void testConnectionRejected() { CompletionException e = assertThrows(CompletionException.class, - () -> new WSClient(vertx).connect(new WebSocketConnectOptions().addSubProtocol("oak"), endUri)); + () -> { + try (WSClient connect = new WSClient(vertx).connect(new WebSocketConnectOptions().addSubProtocol("oak"), + endUri)) { + // handshake should fail + } + }); Throwable cause = e.getCause(); assertTrue(cause instanceof WebSocketClientHandshakeException); assertFalse(Endpoint.OPEN_CALLED.get()); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerBuildConfig.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerBuildConfig.java index 94860bcd0c18f..9cb3a2e8cd16a 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerBuildConfig.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerBuildConfig.java @@ -10,20 +10,28 @@ public interface WebSocketsServerBuildConfig { /** - * Specifies whether to activate the CDI request context when an endpoint callback is invoked. By default, the request - * context is only activated if needed. + * Specifies the activation strategy for the CDI request context during endpoint callback invocation. By default, the + * request context is only activated if needed, i.e. if there is a bean with the given scope, or a bean annotated + * with a security annotation (such as {@code @RolesAllowed}), in the dependency tree of the endpoint. */ @WithDefault("auto") - RequestContextActivation activateRequestContext(); + ContextActivation activateRequestContext(); - enum RequestContextActivation { + /** + * Specifies the activation strategy for the CDI session context during endpoint callback invocation. By default, the + * session context is only activated if needed, i.e. if there is a bean with the given scope in the dependency tree of the + * endpoint. + */ + @WithDefault("auto") + ContextActivation activateSessionContext(); + + enum ContextActivation { /** - * The request context is only activated if needed, i.e. if there is a request scoped bean , or a bean annotated - * with a security annotation (such as {@code @RolesAllowed}) in the dependency tree of the endpoint. + * The context is only activated if needed. */ AUTO, /** - * The request context is always activated. + * The context is always activated. */ ALWAYS } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java index 7b4a605d8ddc1..2b60536dfc45b 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java @@ -1,5 +1,7 @@ package io.quarkus.websockets.next.runtime; +import java.util.Objects; + import org.jboss.logging.Logger; import io.quarkus.arc.InjectableContext.ContextState; @@ -9,6 +11,9 @@ import io.smallrye.common.vertx.VertxContext; import io.vertx.core.Context; +/** + * Per-endpoint CDI context support. + */ public class ContextSupport { private static final Logger LOG = Logger.getLogger(ContextSupport.class); @@ -24,9 +29,9 @@ public class ContextSupport { WebSocketSessionContext sessionContext, ManagedContext requestContext) { this.connection = connection; - this.sessionContextState = sessionContextState; this.sessionContext = sessionContext; this.requestContext = requestContext; + this.sessionContextState = sessionContext != null ? Objects.requireNonNull(sessionContextState) : null; } void start() { @@ -42,8 +47,10 @@ void start(ContextState requestContextState) { } void startSession() { - // Activate the captured session context - sessionContext.activate(sessionContextState); + if (sessionContext != null) { + // Activate the captured session context + sessionContext.activate(sessionContextState); + } } void end(boolean terminateSession) { @@ -63,13 +70,15 @@ void end(boolean terminateRequest, boolean terminateSession) { if (terminateSession) { // OnClose - terminate the session context endSession(); - } else { + } else if (sessionContext != null) { sessionContext.deactivate(); } } void endSession() { - sessionContext.terminate(); + if (sessionContext != null) { + sessionContext.terminate(); + } } static Context createNewDuplicatedContext(Context context, WebSocketConnectionBase connection) { diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java index 8f4b8653a338e..2d99f1f133e7e 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -35,17 +35,21 @@ class Endpoints { static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSocketConnectionBase connection, WebSocketBase ws, String generatedEndpointClass, Optional autoPingInterval, SecuritySupport securitySupport, UnhandledFailureStrategy unhandledFailureStrategy, TrafficLogger trafficLogger, - Runnable onClose, boolean activateRequestContext, TelemetrySupport telemetrySupport) { + Runnable onClose, boolean activateRequestContext, boolean activateSessionContext, + TelemetrySupport telemetrySupport) { Context context = vertx.getOrCreateContext(); // Initialize and capture the session context state that will be activated // during message processing - WebSocketSessionContext sessionContext = sessionContext(container); - SessionContextState sessionContextState = sessionContext.initializeContextState(); + WebSocketSessionContext sessionContext = null; + SessionContextState sessionContextState = null; + if (activateSessionContext) { + sessionContext = sessionContext(container); + sessionContextState = sessionContext.initializeContextState(); + } ContextSupport contextSupport = new ContextSupport(connection, sessionContextState, - sessionContext(container), - activateRequestContext ? container.requestContext() : null); + sessionContext, activateRequestContext ? container.requestContext() : null); // Create an endpoint that delegates callbacks to the endpoint bean WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, contextSupport, diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java index 427c9f8abe087..3a7fd337e4768 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java @@ -148,7 +148,7 @@ public void handle(AsyncResult r) { () -> { connectionManager.remove(clientEndpoint.generatedEndpointClass, connection); client.get().close(); - }, true, telemetrySupport); + }, true, true, telemetrySupport); return connection; }); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index f738e265fae1a..5d83d17b8550b 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -62,7 +62,7 @@ public Object get() { } public Handler createEndpointHandler(String generatedEndpointClass, String endpointId, - boolean activateRequestContext, String endpointPath) { + boolean activateRequestContext, boolean activateSessionContext, String endpointPath) { ArcContainer container = Arc.container(); ConnectionManager connectionManager = container.instance(ConnectionManager.class).get(); Codecs codecs = container.instance(Codecs.class).get(); @@ -125,7 +125,7 @@ public void handle(Throwable throwable) { Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass, config.autoPingInterval(), securitySupport, config.unhandledFailureStrategy(), trafficLogger, () -> connectionManager.remove(generatedEndpointClass, connection), activateRequestContext, - telemetrySupport); + activateSessionContext, telemetrySupport); }); }