Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the capability to aggregate WebSocket continuation frames. #5357

Merged
merged 5 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ final class DefaultWebSocketClient implements WebSocketClient {
private final boolean allowMaskMismatch;
private final List<String> subprotocols;
private final String joinedSubprotocols;
private final boolean aggregateContinuation;

DefaultWebSocketClient(WebClient webClient, int maxFramePayloadLength, boolean allowMaskMismatch,
List<String> subprotocols) {
List<String> subprotocols, boolean aggregateContinuation) {
this.webClient = webClient;
this.maxFramePayloadLength = maxFramePayloadLength;
this.allowMaskMismatch = allowMaskMismatch;
Expand All @@ -78,6 +79,7 @@ final class DefaultWebSocketClient implements WebSocketClient {
} else {
joinedSubprotocols = "";
}
this.aggregateContinuation = aggregateContinuation;
}

@Override
Expand Down Expand Up @@ -115,7 +117,8 @@ public CompletableFuture<WebSocketSession> connect(String path) {
}

final WebSocketClientFrameDecoder decoder =
new WebSocketClientFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch);
new WebSocketClientFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch,
aggregateContinuation);
final WebSocketWrapper inbound = new WebSocketWrapper(split.body().decode(decoder, ctx.alloc()));

result.complete(new WebSocketSession(ctx, responseHeaders, inbound, outboundFuture, encoder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import com.linecorp.armeria.common.auth.BasicToken;
import com.linecorp.armeria.common.auth.OAuth1aToken;
import com.linecorp.armeria.common.auth.OAuth2Token;
import com.linecorp.armeria.common.websocket.WebSocketFrameType;

/**
* Builds a {@link WebSocketClient}.
Expand All @@ -80,6 +81,7 @@ public final class WebSocketClientBuilder extends AbstractWebClientBuilder {
private int maxFramePayloadLength = DEFAULT_MAX_FRAME_PAYLOAD_LENGTH;
private boolean allowMaskMismatch;
private List<String> subprotocols = ImmutableList.of();
private boolean aggregateContinuation;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor question) Maybe it is just me, but I feel like most users won't be interested in continuation frames and true would be the more sensible default.

Copy link
Contributor

@ikhoon ikhoon Jan 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may provide a higher API that can automatically aggregate continuation frames with StreamMessage<String> or StreamMessage<(ByteBuf|byte[])>.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

most users won't be interested in continuation frame

I completely agree with your perspective on this matter, but I believe that setting true for aggregateContinuation as a default is a separate consideration.
If continuation frames are used that the server developers do not notice, the server may experience memory pressure. So I prefer to set it false. 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may provide a higher API that can automatically aggregate continuation frames with StreamMessage or StreamMessage<(ByteBuf|byte[])>.

Could you share an example of the usage? I'm not sure how it's going to be used because what we need here is the stream message of WebSocketFrame instead of strings or byte arrays. 🤔

Copy link
Contributor

@ikhoon ikhoon Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Continuation is a low-level detail that many users might not be interested in for their business logic. I think we can abstract a set of WebSocket frames followed by continuation frames with StreamMessage so that users easily aggregate the whole message or subscribe to the messages one by one.

TextStreamMessage extends StreamMessage<String> {
   CompletableFuture<String> aggregate();
}

ByteStreamMessage extends StreamMessage<byte[]> {
   CompletableFuture<byte[]> aggregate();
}

interface WebSocketStreamHandler {
   void onText(RequestContext ctx, TextStreamMessage texts, WebSocketWriter out);
   
   void onBinary(RequestContext ctx, ByteStreamMessage bytes, WebSocketWriter out);
   ...  
}
class EchoWebSocketStreamHandler implements WebSocketStreamHandler {
   void onText(RequestContext ctx, TextStreamMessage texts, WebSocketWriter out) {
     // If users want to get the whole text
     texts.aggregate().thenAccept(text -> {
        out.write(text);
     });
     // To get a text frame one by one with backpressure
     texts.subscribe(...)
   }
   
   void onBinary(RequestContext ctx, ByteStreamMessage bytes, WebSocketWriter out) {
       ...
   }
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it. I think I've seen a similar API from AKKA whether I like it or not. 😆

We can probably provide two different web socket handler APIs later.
The first one is the API you showed above and the second one is the simple callback style that you've used in your k8s PR.
https://github.com/line/armeria/pull/5167/files#diff-6dfa9f45d08fd2befd27967eec61221bcb1489bb0ceef6764d39f93aa0369d2bR138

For the second one, I think using this boolean value is more proper because the first one introduces another layer with a completable future that isn't needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. We don't know which one is preferred by users. Both versions could be provided later if necessary.


WebSocketClientBuilder(URI uri) {
super(validateUri(requireNonNull(uri, "uri")), null, null, null);
Expand Down Expand Up @@ -184,6 +186,17 @@ public WebSocketClientBuilder subprotocols(Iterable<String> subprotocols) {
return this;
}

/**
* Sets whether to aggregate the subsequent continuation frames of the incoming
* {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame into a single
* {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame.
* Note that enabling this feature may lead to increased memory usage, so use it with caution.
*/
public WebSocketClientBuilder aggregateContinuation(boolean aggregateContinuation) {
this.aggregateContinuation = aggregateContinuation;
return this;
}

/**
* Sets whether to add an {@link HttpHeaderNames#ORIGIN} header automatically when sending
* an {@link HttpRequest} when the {@link HttpRequest#headers()} does not have it.
Expand All @@ -200,7 +213,8 @@ public WebSocketClientBuilder autoFillOriginHeader(boolean autoFillOriginHeader)
*/
public WebSocketClient build() {
final WebClient webClient = buildWebClient();
return new DefaultWebSocketClient(webClient, maxFramePayloadLength, allowMaskMismatch, subprotocols);
return new DefaultWebSocketClient(webClient, maxFramePayloadLength, allowMaskMismatch, subprotocols,
aggregateContinuation);
}

// Override the return type of the chaining methods in the superclass.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ final class WebSocketClientFrameDecoder extends WebSocketFrameDecoder {
private final ClientRequestContext ctx;

WebSocketClientFrameDecoder(ClientRequestContext ctx, int maxFramePayloadLength,
boolean allowMaskMismatch) {
super(maxFramePayloadLength, allowMaskMismatch);
boolean allowMaskMismatch, boolean aggregateContinuation) {
super(maxFramePayloadLength, allowMaskMismatch, aggregateContinuation);
this.ctx = ctx;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@

package com.linecorp.armeria.internal.common.websocket;

import java.util.ArrayList;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linecorp.armeria.common.Bytes;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.stream.HttpDecoder;
import com.linecorp.armeria.common.stream.StreamDecoderInput;
Expand Down Expand Up @@ -65,6 +69,9 @@ enum State {

private final int maxFramePayloadLength;
private final boolean allowMaskMismatch;
private final boolean aggregateContinuation;
private final List<WebSocketFrame> aggregatingFrames = new ArrayList<>();
private long aggregatingFramesLength;
@Nullable
private WebSocket outboundFrames;

Expand All @@ -79,9 +86,11 @@ enum State {
private boolean receivedClosingHandshake;
private State state = State.READING_FIRST;

protected WebSocketFrameDecoder(int maxFramePayloadLength, boolean allowMaskMismatch) {
protected WebSocketFrameDecoder(int maxFramePayloadLength, boolean allowMaskMismatch,
boolean aggregateContinuation) {
this.maxFramePayloadLength = maxFramePayloadLength;
this.allowMaskMismatch = allowMaskMismatch;
this.aggregateContinuation = aggregateContinuation;
}

public void setOutboundWebSocket(WebSocket outboundFrames) {
Expand Down Expand Up @@ -251,6 +260,7 @@ public void process(StreamDecoderInput in, StreamDecoderOutput<WebSocketFrame> o
logger.trace("{} is decoded.", decodedFrame);
continue; // to while loop
}

assert payloadBuffer != null;
if (frameOpcode == WebSocketFrameType.PONG.opcode()) {
final WebSocketFrame decodedFrame = WebSocketFrame.ofPooledPong(payloadBuffer);
Expand All @@ -268,25 +278,57 @@ public void process(StreamDecoderInput in, StreamDecoderOutput<WebSocketFrame> o
continue; // to while loop
}

if (frameOpcode != WebSocketFrameType.TEXT.opcode() &&
frameOpcode != WebSocketFrameType.BINARY.opcode() &&
frameOpcode != WebSocketFrameType.CONTINUATION.opcode()) {
throw protocolViolation(WebSocketCloseStatus.INVALID_MESSAGE_TYPE,
"Cannot decode a web socket frame with opcode: " + frameOpcode);
}

final WebSocketFrame decodedFrame;
if (frameOpcode == WebSocketFrameType.TEXT.opcode()) {
decodedFrame = WebSocketFrame.ofPooledText(payloadBuffer, finalFragment);
out.add(decodedFrame);
} else if (frameOpcode == WebSocketFrameType.BINARY.opcode()) {
decodedFrame = WebSocketFrame.ofPooledBinary(payloadBuffer, finalFragment);
out.add(decodedFrame);
} else if (frameOpcode == WebSocketFrameType.CONTINUATION.opcode()) {
decodedFrame = WebSocketFrame.ofPooledContinuation(payloadBuffer, finalFragment);
out.add(decodedFrame);
} else {
throw protocolViolation(WebSocketCloseStatus.INVALID_MESSAGE_TYPE,
"Cannot decode a web socket frame with opcode: " + frameOpcode);
assert frameOpcode == WebSocketFrameType.CONTINUATION.opcode();
decodedFrame = WebSocketFrame.ofPooledContinuation(payloadBuffer, finalFragment);
}
logger.trace("{} is decoded.", decodedFrame);

if (finalFragment) {
fragmentedFramesCount = 0;
ikhoon marked this conversation as resolved.
Show resolved Hide resolved
aggregatingFramesLength = 0;
if (aggregatingFrames.isEmpty()) {
out.add(decodedFrame);
} else {
aggregatingFrames.add(decodedFrame);
ikhoon marked this conversation as resolved.
Show resolved Hide resolved
final ByteBuf[] byteBufs = aggregatingFrames.stream()
.map(Bytes::byteBuf)
.toArray(ByteBuf[]::new);
if (aggregatingFrames.get(0).type() == WebSocketFrameType.TEXT) {
out.add(WebSocketFrame.ofPooledText(Unpooled.wrappedBuffer(byteBufs), true));
} else {
out.add(WebSocketFrame.ofPooledBinary(Unpooled.wrappedBuffer(byteBufs), true));
}
aggregatingFrames.clear();
}
} else {
fragmentedFramesCount++;
if (aggregateContinuation) {
aggregatingFramesLength += framePayloadLength;
aggregatingFrames.add(decodedFrame);
if (aggregatingFramesLength > maxFramePayloadLength) {
// decodedFrame is release in processOnError.
throw protocolViolation(
ikhoon marked this conversation as resolved.
Show resolved Hide resolved
WebSocketCloseStatus.MESSAGE_TOO_BIG,
"The length of aggregated frames exceeded the max frame length. " +
" aggregated length: " + aggregatingFramesLength +
", max frame length: " + maxFramePayloadLength);
}
} else {
out.add(decodedFrame);
}
}
continue; // to while loop
default:
Expand Down Expand Up @@ -361,8 +403,15 @@ private void validateCloseFrame(ByteBuf buffer) {
}
}

@Override
public void processOnComplete(StreamDecoderInput in, StreamDecoderOutput<WebSocketFrame> out)
throws Exception {
cleanup();
}

@Override
public void processOnError(Throwable cause) {
cleanup();
// If an exception from the inbound stream is raised after receiving a close frame,
// we should not abort the outbound stream.
if (!receivedClosingHandshake) {
Expand All @@ -374,4 +423,13 @@ public void processOnError(Throwable cause) {
}

protected void onProcessOnError(Throwable cause) {}

private void cleanup() {
if (!aggregatingFrames.isEmpty()) {
for (WebSocketFrame frame : aggregatingFrames) {
frame.close();
}
aggregatingFrames.clear();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.linecorp.armeria.internal.common.websocket;

import static com.google.common.base.Strings.isNullOrEmpty;
import static io.netty.util.AsciiString.contentEquals;
import static io.netty.util.AsciiString.contentEqualsIgnoreCase;
import static io.netty.util.AsciiString.trim;
Expand All @@ -29,6 +30,7 @@
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.websocket.CloseWebSocketFrame;
import com.linecorp.armeria.common.websocket.WebSocketCloseStatus;
import com.linecorp.armeria.common.websocket.WebSocketFrame;
Expand Down Expand Up @@ -146,12 +148,25 @@ public static CloseWebSocketFrame newCloseWebSocketFrame(Throwable cause) {
} else {
closeStatus = WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
}
String reasonPhrase = cause.getMessage();
// If the length of the phrase exceeds 125 characters, it is truncated to satisfy the
// <a href="https://datatracker.ietf.org/doc/html/rfc6455#section-5.5">specification</a>.
String reasonPhrase = truncate(cause.getMessage());
if (reasonPhrase == null) {
reasonPhrase = closeStatus.reasonPhrase();
}
return WebSocketFrame.ofClose(closeStatus, reasonPhrase);
}

@Nullable
private static String truncate(@Nullable String message) {
if (isNullOrEmpty(message)) {
return null;
}
if (message.length() <= 125) {
return message;
}
return message.substring(0, 111) + "...(truncated)"; // + 14 characters
}

private WebSocketUtil() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,20 @@ public final class DefaultWebSocketService implements WebSocketService, WebSocke
private final Set<String> subprotocols;
private final Set<String> allowedOrigins;
private final boolean allowAnyOrigin;
private final boolean aggregateContinuation;

public DefaultWebSocketService(WebSocketServiceHandler handler, @Nullable HttpService fallbackService,
int maxFramePayloadLength, boolean allowMaskMismatch,
Set<String> subprotocols, Set<String> allowedOrigins,
boolean allowAnyOrigin) {
boolean allowAnyOrigin, boolean aggregateContinuation) {
this.handler = handler;
this.fallbackService = fallbackService;
this.maxFramePayloadLength = maxFramePayloadLength;
this.allowMaskMismatch = allowMaskMismatch;
this.subprotocols = subprotocols;
this.allowedOrigins = allowedOrigins;
this.allowAnyOrigin = allowAnyOrigin;
this.aggregateContinuation = aggregateContinuation;
}

@Override
Expand Down Expand Up @@ -339,7 +341,8 @@ private static HttpResponse checkVersion(RequestHeaders headers) {
@Override
public WebSocket decode(ServiceRequestContext ctx, HttpRequest req) {
final WebSocketServiceFrameDecoder decoder =
new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch);
new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch,
aggregateContinuation);
ctx.setAttr(DECODER, decoder);
return new WebSocketWrapper(req.decode(decoder, ctx.alloc()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ final class WebSocketServiceFrameDecoder extends WebSocketFrameDecoder {
private final ServiceRequestContext ctx;

WebSocketServiceFrameDecoder(ServiceRequestContext ctx, int maxFramePayloadLength,
boolean allowMaskMismatch) {
super(maxFramePayloadLength, allowMaskMismatch);
boolean allowMaskMismatch, boolean aggregateContinuation) {
super(maxFramePayloadLength, allowMaskMismatch, aggregateContinuation);
this.ctx = ctx;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.websocket.WebSocketCloseStatus;
import com.linecorp.armeria.common.websocket.WebSocketFrameType;
import com.linecorp.armeria.internal.common.websocket.WebSocketUtil;
import com.linecorp.armeria.internal.server.websocket.DefaultWebSocketService;
import com.linecorp.armeria.server.HttpService;
Expand Down Expand Up @@ -61,6 +62,7 @@ public final class WebSocketServiceBuilder {
private boolean allowMaskMismatch;
private Set<String> subprotocols = ImmutableSet.of();
private Set<String> allowedOrigins = ImmutableSet.of();
private boolean aggregateContinuation;
@Nullable
private HttpService fallbackService;

Expand Down Expand Up @@ -110,6 +112,19 @@ public WebSocketServiceBuilder subprotocols(Iterable<String> subprotocols) {
return this;
}

/**
* Sets whether to aggregate the subsequent continuation frames of the incoming
* {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame into a single
* {@link WebSocketFrameType#TEXT} or {@link WebSocketFrameType#BINARY} frame.
* If the length of the aggregated frames exceeds the {@link #maxFramePayloadLength(int)},
* a close frame with the status {@link WebSocketCloseStatus#MESSAGE_TOO_BIG} is sent to the peer.
* Note that enabling this feature may lead to increased memory usage, so use it with caution.
*/
public WebSocketServiceBuilder aggregateContinuation(boolean aggregateContinuation) {
this.aggregateContinuation = aggregateContinuation;
return this;
}

/**
* Sets the allowed origins. The same-origin is allowed by default.
* Specify {@value ANY_ORIGIN} to allow any origins.
Expand Down Expand Up @@ -160,6 +175,7 @@ public WebSocketServiceBuilder fallbackService(HttpService fallbackService) {
*/
public WebSocketService build() {
return new DefaultWebSocketService(handler, fallbackService, maxFramePayloadLength, allowMaskMismatch,
subprotocols, allowedOrigins, allowedOrigins.contains(ANY_ORIGIN));
subprotocols, allowedOrigins, allowedOrigins.contains(ANY_ORIGIN),
aggregateContinuation);
}
}
Loading