From 90206bbf52e73f4af329d4b95c8dd0a7a966f827 Mon Sep 17 00:00:00 2001 From: songmw725 Date: Fri, 22 Dec 2023 14:43:56 +0900 Subject: [PATCH 1/4] Introduce the capability to aggregate WebSocket continuation frames. Motivation: Users may find it beneficial to aggregate WebSocket continuation frames and their text or binary frame into a single frame instead of managing them individually. Modifications: - Introduce properties `WebSocketServiceBuilder.aggregateContinuation()` and `WebSocketClientBuilder.aggregateContinuation()` to enable the aggregation of continuation frames. Result: - You now have the option to aggregate WebSocket continuation frames for more streamlined handling. --- .../websocket/DefaultWebSocketClient.java | 7 +- .../websocket/WebSocketClientBuilder.java | 16 +- .../WebSocketClientFrameDecoder.java | 4 +- .../websocket/WebSocketFrameDecoder.java | 66 ++++++-- .../server/websocket/WebSocketService.java | 8 +- .../websocket/WebSocketServiceBuilder.java | 16 +- .../WebSocketServiceFrameDecoder.java | 4 +- ...SocketClientAggregateContinuationTest.java | 142 +++++++++++++++++ .../client/websocket/WebSocketClientTest.java | 47 +----- .../WebSocketInboundTestHandler.java | 73 +++++++++ .../WebSocketFrameEncoderAndDecoderTest.java | 9 +- ...ocketServiceAggregateContinuationTest.java | 147 ++++++++++++++++++ 12 files changed, 467 insertions(+), 72 deletions(-) create mode 100644 core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientAggregateContinuationTest.java create mode 100644 core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketInboundTestHandler.java create mode 100644 core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/DefaultWebSocketClient.java b/core/src/main/java/com/linecorp/armeria/client/websocket/DefaultWebSocketClient.java index e3f2eed09e6..4b4fc789495 100644 --- a/core/src/main/java/com/linecorp/armeria/client/websocket/DefaultWebSocketClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/DefaultWebSocketClient.java @@ -66,9 +66,10 @@ final class DefaultWebSocketClient implements WebSocketClient { private final boolean allowMaskMismatch; private final List subprotocols; private final String joinedSubprotocols; + private final boolean aggregateContinuation; DefaultWebSocketClient(WebClient webClient, int maxFramePayloadLength, boolean allowMaskMismatch, - List subprotocols) { + List subprotocols, boolean aggregateContinuation) { this.webClient = webClient; this.maxFramePayloadLength = maxFramePayloadLength; this.allowMaskMismatch = allowMaskMismatch; @@ -78,6 +79,7 @@ final class DefaultWebSocketClient implements WebSocketClient { } else { joinedSubprotocols = ""; } + this.aggregateContinuation = aggregateContinuation; } @Override @@ -115,7 +117,8 @@ public CompletableFuture connect(String path) { } final WebSocketClientFrameDecoder decoder = - new WebSocketClientFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); + new WebSocketClientFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch, + aggregateContinuation); final WebSocketWrapper inbound = new WebSocketWrapper(split.body().decode(decoder, ctx.alloc())); result.complete(new WebSocketSession(ctx, responseHeaders, inbound, outboundFuture, encoder)); diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java index 6a3deec8092..8f895151658 100644 --- a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java @@ -61,6 +61,7 @@ import com.linecorp.armeria.common.auth.BasicToken; import com.linecorp.armeria.common.auth.OAuth1aToken; import com.linecorp.armeria.common.auth.OAuth2Token; +import com.linecorp.armeria.common.websocket.WebSocketFrameType; /** * Builds a {@link WebSocketClient}. @@ -80,6 +81,7 @@ public final class WebSocketClientBuilder extends AbstractWebClientBuilder { private int maxFramePayloadLength = DEFAULT_MAX_FRAME_PAYLOAD_LENGTH; private boolean allowMaskMismatch; private List subprotocols = ImmutableList.of(); + private boolean aggregateContinuation; WebSocketClientBuilder(URI uri) { super(validateUri(requireNonNull(uri, "uri")), null, null, null); @@ -184,6 +186,17 @@ public WebSocketClientBuilder subprotocols(Iterable subprotocols) { return this; } + /** + * Sets whether to aggregate the subsequent continuation frames of the incoming + * {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame into a single + * {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame. + * Note that enabling this feature may lead to increased memory usage, so use it with caution. + */ + public WebSocketClientBuilder aggregateContinuation(boolean aggregateContinuation) { + this.aggregateContinuation = aggregateContinuation; + return this; + } + /** * Sets whether to add an {@link HttpHeaderNames#ORIGIN} header automatically when sending * an {@link HttpRequest} when the {@link HttpRequest#headers()} does not have it. @@ -200,7 +213,8 @@ public WebSocketClientBuilder autoFillOriginHeader(boolean autoFillOriginHeader) */ public WebSocketClient build() { final WebClient webClient = buildWebClient(); - return new DefaultWebSocketClient(webClient, maxFramePayloadLength, allowMaskMismatch, subprotocols); + return new DefaultWebSocketClient(webClient, maxFramePayloadLength, allowMaskMismatch, subprotocols, + aggregateContinuation); } // Override the return type of the chaining methods in the superclass. diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java index 7390953ac5f..fc95c592842 100644 --- a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java @@ -24,8 +24,8 @@ final class WebSocketClientFrameDecoder extends WebSocketFrameDecoder { private final ClientRequestContext ctx; WebSocketClientFrameDecoder(ClientRequestContext ctx, int maxFramePayloadLength, - boolean allowMaskMismatch) { - super(ctx, maxFramePayloadLength, allowMaskMismatch); + boolean allowMaskMismatch, boolean aggregateContinuation) { + super(maxFramePayloadLength, allowMaskMismatch, aggregateContinuation); this.ctx = ctx; } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java index ad31efcb993..023c7901c07 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java @@ -31,10 +31,13 @@ package com.linecorp.armeria.internal.common.websocket; +import java.util.ArrayList; +import java.util.List; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.Bytes; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.stream.HttpDecoder; import com.linecorp.armeria.common.stream.StreamDecoderInput; @@ -64,9 +67,10 @@ enum State { CORRUPT } - private final RequestContext ctx; private final int maxFramePayloadLength; private final boolean allowMaskMismatch; + private final boolean aggregateContinuation; + private final List aggregatingFrames = new ArrayList<>(); @Nullable private WebSocket outboundFrames; @@ -81,10 +85,11 @@ enum State { private boolean receivedClosingHandshake; private State state = State.READING_FIRST; - protected WebSocketFrameDecoder(RequestContext ctx, int maxFramePayloadLength, boolean allowMaskMismatch) { - this.ctx = ctx; + protected WebSocketFrameDecoder(int maxFramePayloadLength, boolean allowMaskMismatch, + boolean aggregateContinuation) { this.maxFramePayloadLength = maxFramePayloadLength; this.allowMaskMismatch = allowMaskMismatch; + this.aggregateContinuation = aggregateContinuation; } public void setOutboundWebSocket(WebSocket outboundFrames) { @@ -254,6 +259,7 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o logger.trace("{} is decoded.", decodedFrame); continue; // to while loop } + assert payloadBuffer != null; if (frameOpcode == WebSocketFrameType.PONG.opcode()) { final WebSocketFrame decodedFrame = WebSocketFrame.ofPooledPong(payloadBuffer); @@ -271,25 +277,47 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o continue; // to while loop } + if (frameOpcode != WebSocketFrameType.TEXT.opcode() && + frameOpcode != WebSocketFrameType.BINARY.opcode() && + frameOpcode != WebSocketFrameType.CONTINUATION.opcode()) { + throw protocolViolation(WebSocketCloseStatus.INVALID_MESSAGE_TYPE, + "Cannot decode a web socket frame with opcode: " + frameOpcode); + } + final WebSocketFrame decodedFrame; if (frameOpcode == WebSocketFrameType.TEXT.opcode()) { decodedFrame = WebSocketFrame.ofPooledText(payloadBuffer, finalFragment); - out.add(decodedFrame); } else if (frameOpcode == WebSocketFrameType.BINARY.opcode()) { decodedFrame = WebSocketFrame.ofPooledBinary(payloadBuffer, finalFragment); - out.add(decodedFrame); - } else if (frameOpcode == WebSocketFrameType.CONTINUATION.opcode()) { - decodedFrame = WebSocketFrame.ofPooledContinuation(payloadBuffer, finalFragment); - out.add(decodedFrame); } else { - throw protocolViolation(WebSocketCloseStatus.INVALID_MESSAGE_TYPE, - "Cannot decode a web socket frame with opcode: " + frameOpcode); + assert frameOpcode == WebSocketFrameType.CONTINUATION.opcode(); + decodedFrame = WebSocketFrame.ofPooledContinuation(payloadBuffer, finalFragment); } logger.trace("{} is decoded.", decodedFrame); + if (finalFragment) { fragmentedFramesCount = 0; + if (aggregatingFrames.isEmpty()) { + out.add(decodedFrame); + } else { + aggregatingFrames.add(decodedFrame); + final ByteBuf[] byteBufs = aggregatingFrames.stream() + .map(Bytes::byteBuf) + .toArray(ByteBuf[]::new); + if (aggregatingFrames.get(0).type() == WebSocketFrameType.TEXT) { + out.add(WebSocketFrame.ofPooledText(Unpooled.wrappedBuffer(byteBufs), true)); + } else { + out.add(WebSocketFrame.ofPooledBinary(Unpooled.wrappedBuffer(byteBufs), true)); + } + aggregatingFrames.clear(); + } } else { fragmentedFramesCount++; + if (aggregateContinuation) { + aggregatingFrames.add(decodedFrame); + } else { + out.add(decodedFrame); + } } continue; // to while loop default: @@ -364,8 +392,15 @@ private void validateCloseFrame(ByteBuf buffer) { } } + @Override + public void processOnComplete(StreamDecoderInput in, StreamDecoderOutput out) + throws Exception { + cleanup(); + } + @Override public void processOnError(Throwable cause) { + cleanup(); // If an exception from the inbound stream is raised after receiving a close frame, // we should not abort the outbound stream. if (!receivedClosingHandshake) { @@ -377,4 +412,13 @@ public void processOnError(Throwable cause) { } protected void onProcessOnError(Throwable cause) {} + + private void cleanup() { + if (!aggregatingFrames.isEmpty()) { + for (WebSocketFrame frame : aggregatingFrames) { + frame.close(); + } + aggregatingFrames.clear(); + } + } } diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java index 9c53cf2e553..b819f07bf7e 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java @@ -98,15 +98,18 @@ public static WebSocketServiceBuilder builder(WebSocketServiceHandler handler) { private final Set subprotocols; private final Set allowedOrigins; private final boolean allowAnyOrigin; + private final boolean aggregateContinuation; WebSocketService(WebSocketServiceHandler handler, int maxFramePayloadLength, boolean allowMaskMismatch, - Set subprotocols, Set allowedOrigins, boolean allowAnyOrigin) { + Set subprotocols, Set allowedOrigins, boolean allowAnyOrigin, + boolean aggregateContinuation) { this.handler = handler; this.maxFramePayloadLength = maxFramePayloadLength; this.allowMaskMismatch = allowMaskMismatch; this.subprotocols = subprotocols; this.allowedOrigins = allowedOrigins; this.allowAnyOrigin = allowAnyOrigin; + this.aggregateContinuation = aggregateContinuation; } /** @@ -190,7 +193,8 @@ private void maybeAddSubprotocol(RequestHeaders headers, private HttpResponse handleUpgradeRequest(ServiceRequestContext ctx, HttpRequest req, ResponseHeaders responseHeaders) { final WebSocketServiceFrameDecoder decoder = - new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); + new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch, + aggregateContinuation); final StreamMessage inboundFrames = req.decode(decoder, ctx.alloc()); final WebSocket outboundFrames = handler.handle(ctx, new WebSocketWrapper(inboundFrames)); decoder.setOutboundWebSocket(outboundFrames); diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java index 4003a934257..bbf73d529df 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java @@ -28,6 +28,7 @@ import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; +import com.linecorp.armeria.common.websocket.WebSocketFrameType; import com.linecorp.armeria.internal.common.websocket.WebSocketUtil; import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.ServiceConfig; @@ -59,6 +60,7 @@ public final class WebSocketServiceBuilder { private boolean allowMaskMismatch; private Set subprotocols = ImmutableSet.of(); private Set allowedOrigins = ImmutableSet.of(); + private boolean aggregateContinuation; WebSocketServiceBuilder(WebSocketServiceHandler handler) { this.handler = requireNonNull(handler, "handler"); @@ -106,6 +108,17 @@ public WebSocketServiceBuilder subprotocols(Iterable subprotocols) { return this; } + /** + * Sets whether to aggregate the subsequent continuation frames of the incoming + * {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame into a single + * {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame. + * Note that enabling this feature may lead to increased memory usage, so use it with caution. + */ + public WebSocketServiceBuilder aggregateContinuation(boolean aggregateContinuation) { + this.aggregateContinuation = aggregateContinuation; + return this; + } + /** * Sets the allowed origins. The same-origin is allowed by default. * Specify {@value ANY_ORIGIN} to allow any origins. @@ -147,6 +160,7 @@ private static Set validateOrigins(Iterable allowedOrigins) { */ public WebSocketService build() { return new WebSocketService(handler, maxFramePayloadLength, allowMaskMismatch, - subprotocols, allowedOrigins, allowedOrigins.contains(ANY_ORIGIN)); + subprotocols, allowedOrigins, allowedOrigins.contains(ANY_ORIGIN), + aggregateContinuation); } } diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java index 28bbd71aa2c..c4a50e97f5d 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java @@ -26,8 +26,8 @@ final class WebSocketServiceFrameDecoder extends WebSocketFrameDecoder { private final ServiceRequestContext ctx; WebSocketServiceFrameDecoder(ServiceRequestContext ctx, int maxFramePayloadLength, - boolean allowMaskMismatch) { - super(ctx, maxFramePayloadLength, allowMaskMismatch); + boolean allowMaskMismatch, boolean aggregateContinuation) { + super(maxFramePayloadLength, allowMaskMismatch, aggregateContinuation); this.ctx = ctx; } diff --git a/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientAggregateContinuationTest.java b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientAggregateContinuationTest.java new file mode 100644 index 00000000000..8001b1ccdd4 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientAggregateContinuationTest.java @@ -0,0 +1,142 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.client.websocket; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketFrameType; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.websocket.WebSocketService; +import com.linecorp.armeria.server.websocket.WebSocketServiceHandler; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +class WebSocketClientAggregateContinuationTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + sb.service("/", WebSocketService.of(new WebSocketEchoHandler())); + } + }; + + @CsvSource({ "true", "false" }) + @ParameterizedTest + void aggregateFrames(boolean aggregate) throws InterruptedException { + final WebSocketClient webSocketClient = + WebSocketClient.builder(server.httpUri()) + .aggregateContinuation(aggregate) + .build(); + final WebSocketSession webSocketSession = webSocketClient.connect("/").join(); + + final WebSocketWriter outbound = webSocketSession.outbound(); + outbound.write(WebSocketFrame.ofText("Hello", false)); + outbound.write(WebSocketFrame.ofContinuation(" wor", false)); + outbound.write(WebSocketFrame.ofContinuation("ld!", true)); + + final WebSocketInboundTestHandler inboundHandler = new WebSocketInboundTestHandler( + webSocketSession.inbound(), SessionProtocol.H2C); + + WebSocketFrame frame; + if (aggregate) { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofText("Hello world!")); + } else { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofText("Hello", false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation(" wor", false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation("ld!")); + } + + frame = inboundHandler.inboundQueue().poll(1, TimeUnit.SECONDS); + assertThat(frame).isNull(); + + outbound.write(WebSocketFrame.ofBinary("Hello".getBytes(), false)); + outbound.write(WebSocketFrame.ofContinuation(" wor".getBytes(), false)); + outbound.write(WebSocketFrame.ofContinuation("ld!".getBytes(), true)); + + if (aggregate) { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofBinary("Hello world!".getBytes())); + } else { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofBinary("Hello".getBytes(), false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation(" wor".getBytes(), false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation("ld!".getBytes())); + } + + frame = inboundHandler.inboundQueue().poll(1, TimeUnit.SECONDS); + assertThat(frame).isNull(); + + outbound.close(WebSocketCloseStatus.NORMAL_CLOSURE); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofClose(WebSocketCloseStatus.NORMAL_CLOSURE)); + inboundHandler.completionFuture().join(); + await().until(outbound::isComplete); + } + + static final class WebSocketEchoHandler implements WebSocketServiceHandler { + + @Override + public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { + final WebSocketWriter writer = WebSocket.streaming(); + in.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + if (webSocketFrame.type() != WebSocketFrameType.PING && + webSocketFrame.type() != WebSocketFrameType.PONG) { + writer.write(webSocketFrame); + } + } + + @Override + public void onError(Throwable t) { + writer.close(t); + } + + @Override + public void onComplete() { + writer.close(); + } + }); + return writer; + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientTest.java b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientTest.java index c98417de6a5..46a680a92c9 100644 --- a/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientTest.java @@ -19,7 +19,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; -import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -30,7 +29,6 @@ import org.reactivestreams.Subscription; import com.linecorp.armeria.client.ClientFactory; -import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.SerializationFormat; @@ -98,7 +96,7 @@ void webSocketClient(SessionProtocol protocol, boolean defaultClient) throws Int final WebSocketWriter outbound = webSocketSession.outbound(); outbound.write(WebSocketFrame.ofText("hello")); - final WebSocketInboundHandler inboundHandler = new WebSocketInboundHandler( + final WebSocketInboundTestHandler inboundHandler = new WebSocketInboundTestHandler( webSocketSession.inbound(), protocol); WebSocketFrame frame = inboundHandler.inboundQueue().take(); @@ -118,49 +116,6 @@ void webSocketClient(SessionProtocol protocol, boolean defaultClient) throws Int await().until(outbound::isComplete); } - static final class WebSocketInboundHandler { - - private final ArrayBlockingQueue inboundQueue = new ArrayBlockingQueue<>(4); - private final CompletableFuture completionFuture = new CompletableFuture<>(); - - WebSocketInboundHandler(WebSocket inbound, SessionProtocol protocol) { - inbound.subscribe(new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(WebSocketFrame webSocketFrame) { - inboundQueue.add(webSocketFrame); - } - - @Override - public void onError(Throwable t) { - if (protocol.isExplicitHttp1()) { - // After receiving a close frame, ClosedSessionException can be raised for HTTP/1.1 - // before onComplete is called. - assertThat(t).isExactlyInstanceOf(ClosedSessionException.class); - } - completionFuture.complete(null); - } - - @Override - public void onComplete() { - completionFuture.complete(null); - } - }); - } - - ArrayBlockingQueue inboundQueue() { - return inboundQueue; - } - - CompletableFuture completionFuture() { - return completionFuture; - } - } - static final class WebSocketServiceEchoHandler implements WebSocketServiceHandler { @Override diff --git a/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketInboundTestHandler.java b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketInboundTestHandler.java new file mode 100644 index 00000000000..67b6c351f98 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketInboundTestHandler.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.websocket; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketFrame; + +public final class WebSocketInboundTestHandler { + + private final ArrayBlockingQueue inboundQueue = new ArrayBlockingQueue<>(4); + private final CompletableFuture completionFuture = new CompletableFuture<>(); + + public WebSocketInboundTestHandler(WebSocket inbound, SessionProtocol protocol) { + inbound.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + inboundQueue.add(webSocketFrame); + } + + @Override + public void onError(Throwable t) { + if (protocol.isExplicitHttp1()) { + // After receiving a close frame, ClosedSessionException can be raised for HTTP/1.1 + // before onComplete is called. + assertThat(t).isExactlyInstanceOf(ClosedSessionException.class); + } + completionFuture.complete(null); + } + + @Override + public void onComplete() { + completionFuture.complete(null); + } + }); + } + + public ArrayBlockingQueue inboundQueue() { + return inboundQueue; + } + + public CompletableFuture completionFuture() { + return completionFuture; + } +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java index 409e0ebe583..99a0666ae40 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java @@ -52,7 +52,6 @@ import com.linecorp.armeria.common.HttpRequestWriter; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpResponseWriter; -import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; import com.linecorp.armeria.common.websocket.WebSocketFrame; @@ -115,7 +114,7 @@ public void testWebSocketProtocolViolation() throws InterruptedException { final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(true); final HttpRequestWriter requestWriter = HttpRequest.streaming(RequestHeaders.of(HttpMethod.GET, "/")); final WebSocketFrameDecoder decoder = - new TestWebSocketFrameDecoder(ctx, maxPayloadLength, false, true); + new TestWebSocketFrameDecoder(maxPayloadLength, false, true); final CompletableFuture whenComplete = new CompletableFuture<>(); requestWriter.decode(decoder, ctx.alloc()).subscribe(subscriber(whenComplete)); @@ -142,7 +141,7 @@ public void testWebSocketEncodingAndDecoding(boolean maskPayload, boolean allowM final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(maskPayload); final HttpRequestWriter requestWriter = HttpRequest.streaming(RequestHeaders.of(HttpMethod.GET, "/")); final WebSocketFrameDecoder decoder = new TestWebSocketFrameDecoder( - ctx, 1024 * 1024, allowMaskMismatch, maskPayload); + 1024 * 1024, allowMaskMismatch, maskPayload); requestWriter.decode(decoder, ctx.alloc()).subscribe(subscriber(new CompletableFuture<>())); executeTests(encoder, requestWriter); httpResponseWriter.abort(); @@ -235,9 +234,9 @@ private static class TestWebSocketFrameDecoder extends WebSocketFrameDecoder { private final boolean expectMaskedFrames; - TestWebSocketFrameDecoder(RequestContext ctx, int maxFramePayloadLength, + TestWebSocketFrameDecoder(int maxFramePayloadLength, boolean allowMaskMismatch, boolean expectMaskedFrames) { - super(ctx, maxFramePayloadLength, allowMaskMismatch); + super(maxFramePayloadLength, allowMaskMismatch, false); this.expectMaskedFrames = expectMaskedFrames; } diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java new file mode 100644 index 00000000000..175071a0cf8 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java @@ -0,0 +1,147 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.server.websocket; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketInboundTestHandler; +import com.linecorp.armeria.client.websocket.WebSocketSession; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketFrameType; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +class WebSocketServiceAggregateContinuationTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + final WebSocketEchoHandler handler = new WebSocketEchoHandler(); + sb.service("/aggregate", WebSocketService.builder(handler) + .aggregateContinuation(true) + .build()); + sb.service("/no_aggregate", WebSocketService.builder(handler) + .aggregateContinuation(false) + .build()); + } + }; + + @CsvSource({ "true", "false" }) + @ParameterizedTest + void aggregateFrames(boolean aggregate) throws InterruptedException { + final WebSocketClient webSocketClient = WebSocketClient.of(server.httpUri()); + final String path = aggregate ? "/aggregate" : "/no_aggregate"; + final WebSocketSession webSocketSession = webSocketClient.connect(path).join(); + + final WebSocketWriter outbound = webSocketSession.outbound(); + outbound.write(WebSocketFrame.ofText("Hello", false)); + outbound.write(WebSocketFrame.ofContinuation(" wor", false)); + outbound.write(WebSocketFrame.ofContinuation("ld!", true)); + + final WebSocketInboundTestHandler inboundHandler = new WebSocketInboundTestHandler( + webSocketSession.inbound(), SessionProtocol.H2C); + + WebSocketFrame frame; + if (aggregate) { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofText("Hello world!")); + } else { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofText("Hello", false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation(" wor", false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation("ld!")); + } + + frame = inboundHandler.inboundQueue().poll(1, TimeUnit.SECONDS); + assertThat(frame).isNull(); + + outbound.write(WebSocketFrame.ofBinary("Hello".getBytes(), false)); + outbound.write(WebSocketFrame.ofContinuation(" wor".getBytes(), false)); + outbound.write(WebSocketFrame.ofContinuation("ld!".getBytes(), true)); + + if (aggregate) { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofBinary("Hello world!".getBytes())); + } else { + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofBinary("Hello".getBytes(), false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation(" wor".getBytes(), false)); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofContinuation("ld!".getBytes())); + } + + frame = inboundHandler.inboundQueue().poll(1, TimeUnit.SECONDS); + assertThat(frame).isNull(); + + outbound.close(WebSocketCloseStatus.NORMAL_CLOSURE); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofClose(WebSocketCloseStatus.NORMAL_CLOSURE)); + inboundHandler.completionFuture().join(); + await().until(outbound::isComplete); + } + + static final class WebSocketEchoHandler implements WebSocketServiceHandler { + + @Override + public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { + final WebSocketWriter writer = WebSocket.streaming(); + in.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + if (webSocketFrame.type() != WebSocketFrameType.PING && + webSocketFrame.type() != WebSocketFrameType.PONG) { + writer.write(webSocketFrame); + } + } + + @Override + public void onError(Throwable t) { + writer.close(t); + } + + @Override + public void onComplete() { + writer.close(); + } + }); + return writer; + } + } +} From 5cc2172623b2c29ff0cb9caa3b92eb668eb41b19 Mon Sep 17 00:00:00 2001 From: songmw725 Date: Thu, 28 Dec 2023 11:52:55 +0900 Subject: [PATCH 2/4] Address comments from @ikhoon --- .../websocket/WebSocketFrameDecoder.java | 12 ++++++++++ .../websocket/WebSocketServiceBuilder.java | 2 ++ ...ocketServiceAggregateContinuationTest.java | 24 ++++++++++++++++++- 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java index 023c7901c07..9cd30bf8498 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java @@ -71,6 +71,7 @@ enum State { private final boolean allowMaskMismatch; private final boolean aggregateContinuation; private final List aggregatingFrames = new ArrayList<>(); + private long aggregatingFramesLength; @Nullable private WebSocket outboundFrames; @@ -297,6 +298,7 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o if (finalFragment) { fragmentedFramesCount = 0; + aggregatingFramesLength = 0; if (aggregatingFrames.isEmpty()) { out.add(decodedFrame); } else { @@ -314,6 +316,16 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o } else { fragmentedFramesCount++; if (aggregateContinuation) { + aggregatingFramesLength += framePayloadLength; + if (aggregatingFramesLength > maxFramePayloadLength) { + throw protocolViolation( + WebSocketCloseStatus.MESSAGE_TOO_BIG, + // The message must not exceed 125 length: + // https://datatracker.ietf.org/doc/html/rfc6455/#section-5.5 + "The length of aggregated frames exceeded the max frame length. " + + " aggregated length: " + aggregatingFramesLength + + ", max frame length: " + maxFramePayloadLength); + } aggregatingFrames.add(decodedFrame); } else { out.add(decodedFrame); diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java index bbf73d529df..c39dee1aedd 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java @@ -112,6 +112,8 @@ public WebSocketServiceBuilder subprotocols(Iterable subprotocols) { * Sets whether to aggregate the subsequent continuation frames of the incoming * {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame into a single * {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame. + * If the length of the aggregated frames exceeds the {@link #maxFramePayloadLength(int)}, + * a close frame with the status {@link WebSocketCloseStatus#MESSAGE_TOO_BIG} is sent to the peer. * Note that enabling this feature may lead to increased memory usage, so use it with caution. */ public WebSocketServiceBuilder aggregateContinuation(boolean aggregateContinuation) { diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java index 175071a0cf8..f7dba05ce4f 100644 --- a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceAggregateContinuationTest.java @@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -30,6 +31,7 @@ import com.linecorp.armeria.client.websocket.WebSocketInboundTestHandler; import com.linecorp.armeria.client.websocket.WebSocketSession; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.websocket.CloseWebSocketFrame; import com.linecorp.armeria.common.websocket.WebSocket; import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; import com.linecorp.armeria.common.websocket.WebSocketFrame; @@ -48,6 +50,7 @@ protected void configure(ServerBuilder sb) throws Exception { final WebSocketEchoHandler handler = new WebSocketEchoHandler(); sb.service("/aggregate", WebSocketService.builder(handler) .aggregateContinuation(true) + .maxFramePayloadLength(13) .build()); sb.service("/no_aggregate", WebSocketService.builder(handler) .aggregateContinuation(false) @@ -112,7 +115,26 @@ void aggregateFrames(boolean aggregate) throws InterruptedException { await().until(outbound::isComplete); } - static final class WebSocketEchoHandler implements WebSocketServiceHandler { + @Test + void aggregateFramesExceedMaxLength() throws InterruptedException { + final WebSocketClient webSocketClient = WebSocketClient.of(server.httpUri()); + final WebSocketSession webSocketSession = webSocketClient.connect("/aggregate").join(); + + final WebSocketWriter outbound = webSocketSession.outbound(); + outbound.write(WebSocketFrame.ofText("Hello", false)); + outbound.write(WebSocketFrame.ofContinuation(" wor", false)); + outbound.write(WebSocketFrame.ofContinuation("ld!", false)); + outbound.write(WebSocketFrame.ofContinuation("!!", false)); + + final WebSocketInboundTestHandler inboundHandler = new WebSocketInboundTestHandler( + webSocketSession.inbound(), SessionProtocol.H2C); + + final WebSocketFrame frame = inboundHandler.inboundQueue().take(); + assertThat(frame.type()).isSameAs(WebSocketFrameType.CLOSE); + assertThat(((CloseWebSocketFrame) frame).status()).isSameAs(WebSocketCloseStatus.MESSAGE_TOO_BIG); + } + + static final class WebSocketEchoHandler implements WebSocketServiceHandler { @Override public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { From 5ae09eadddf0b401618b030196d2651901a179e0 Mon Sep 17 00:00:00 2001 From: songmw725 Date: Tue, 9 Jan 2024 17:23:24 +0900 Subject: [PATCH 3/4] Address the comment from @jrhee17 --- .../websocket/WebSocketFrameDecoder.java | 2 - .../common/websocket/WebSocketUtil.java | 17 ++++++++- .../common/websocket/WebSocketUtilTest.java | 38 +++++++++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtilTest.java diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java index 9cd30bf8498..b9a1a25e901 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java @@ -320,8 +320,6 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o if (aggregatingFramesLength > maxFramePayloadLength) { throw protocolViolation( WebSocketCloseStatus.MESSAGE_TOO_BIG, - // The message must not exceed 125 length: - // https://datatracker.ietf.org/doc/html/rfc6455/#section-5.5 "The length of aggregated frames exceeded the max frame length. " + " aggregated length: " + aggregatingFramesLength + ", max frame length: " + maxFramePayloadLength); diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java index 3b42a0ed504..6c46f50491d 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java @@ -15,6 +15,7 @@ */ package com.linecorp.armeria.internal.common.websocket; +import static com.google.common.base.Strings.isNullOrEmpty; import static io.netty.util.AsciiString.contentEquals; import static io.netty.util.AsciiString.contentEqualsIgnoreCase; import static io.netty.util.AsciiString.trim; @@ -29,6 +30,7 @@ import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.websocket.CloseWebSocketFrame; import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; import com.linecorp.armeria.common.websocket.WebSocketFrame; @@ -146,12 +148,25 @@ public static CloseWebSocketFrame newCloseWebSocketFrame(Throwable cause) { } else { closeStatus = WebSocketCloseStatus.INTERNAL_SERVER_ERROR; } - String reasonPhrase = cause.getMessage(); + // If the length of the phrase exceeds 125 characters, it is truncated to satisfy the + // specification. + String reasonPhrase = truncate(cause.getMessage()); if (reasonPhrase == null) { reasonPhrase = closeStatus.reasonPhrase(); } return WebSocketFrame.ofClose(closeStatus, reasonPhrase); } + @Nullable + private static String truncate(@Nullable String message) { + if (isNullOrEmpty(message)) { + return null; + } + if (message.length() <= 125) { + return message; + } + return message.substring(0, 111) + "...(truncated)"; // + 14 characters + } + private WebSocketUtil() {} } diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtilTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtilTest.java new file mode 100644 index 00000000000..88b64ee13d2 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtilTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.internal.common.websocket; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import com.linecorp.armeria.common.websocket.CloseWebSocketFrame; +import com.linecorp.armeria.server.websocket.WebSocketProtocolViolationException; + +import joptsimple.internal.Strings; + +class WebSocketUtilTest { + + @Test + void reasonPhraseTruncate() { + final String reasonPhrase = Strings.repeat('a', 126); + final WebSocketProtocolViolationException exception = + new WebSocketProtocolViolationException(reasonPhrase); + final CloseWebSocketFrame closeWebSocketFrame = WebSocketUtil.newCloseWebSocketFrame(exception); + assertThat(closeWebSocketFrame.reasonPhrase()).isEqualTo( + reasonPhrase.substring(0, 111) + "...(truncated)"); + } +} From 8955f259b362f311e30ef65597017c23ed3b267f Mon Sep 17 00:00:00 2001 From: songmw725 Date: Fri, 19 Jan 2024 15:06:27 +0900 Subject: [PATCH 4/4] Address the comment from @ikhoon --- .../internal/common/websocket/WebSocketFrameDecoder.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java index b9a1a25e901..1c1f8fd8643 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java @@ -317,14 +317,15 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o fragmentedFramesCount++; if (aggregateContinuation) { aggregatingFramesLength += framePayloadLength; + aggregatingFrames.add(decodedFrame); if (aggregatingFramesLength > maxFramePayloadLength) { + // decodedFrame is release in processOnError. throw protocolViolation( WebSocketCloseStatus.MESSAGE_TOO_BIG, "The length of aggregated frames exceeded the max frame length. " + " aggregated length: " + aggregatingFramesLength + ", max frame length: " + maxFramePayloadLength); } - aggregatingFrames.add(decodedFrame); } else { out.add(decodedFrame); }