Skip to content

Commit

Permalink
Add WebSocketClient (#4972)
Browse files Browse the repository at this point in the history
Motivation
It would be nice if we also support WebSocket clients.

Modifications
- Add `SerializationFormat.WS` for WebSocket.
- Add `WebSocketClient` and its builder.
- Extract the common part of `HttpRequestSubciber` to
`AbsractHttpRequestSubciber`
  and add `WebSocketHttp1RequestSubscriber`
- Add `ClientOpions.AUTO_FILL_ORIGIN_HEADER` for adding the header
automatically.
- Add the pipeline Channel handler for WebSocket

Result
- You can now send and receive WebSocket frames using `WebSocketClient`.

To-Do
- Add `WebSocketClientEventHandler`
- A lot of todos that are in this PR
  • Loading branch information
minwoox authored Aug 22, 2023
1 parent 0c0fb4f commit dc2bdc6
Show file tree
Hide file tree
Showing 78 changed files with 3,517 additions and 801 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linecorp.armeria.client.HttpResponseDecoder.HttpResponseWrapper;
import com.linecorp.armeria.common.ClosedSessionException;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
Expand Down Expand Up @@ -59,6 +58,7 @@ abstract class AbstractHttpRequestHandler implements ChannelFutureListener {

enum State {
NEEDS_TO_WRITE_FIRST_HEADER,
NEEDS_DATA,
NEEDS_DATA_OR_TRAILERS,
DONE
}
Expand All @@ -71,6 +71,8 @@ enum State {
private final RequestLogBuilder logBuilder;
private final long timeoutMillis;
private final boolean headersOnly;
private final boolean allowTrailers;
private final boolean keepAlive;

// session, id and responseWrapper are assigned in tryInitialize()
@Nullable
Expand All @@ -86,7 +88,8 @@ enum State {

AbstractHttpRequestHandler(Channel ch, ClientHttpObjectEncoder encoder, HttpResponseDecoder responseDecoder,
DecodedHttpResponse originalRes,
ClientRequestContext ctx, long timeoutMillis, boolean headersOnly) {
ClientRequestContext ctx, long timeoutMillis, boolean headersOnly,
boolean allowTrailers, boolean keepAlive) {
this.ch = ch;
this.encoder = encoder;
this.responseDecoder = responseDecoder;
Expand All @@ -95,6 +98,8 @@ enum State {
logBuilder = ctx.logBuilder();
this.timeoutMillis = timeoutMillis;
this.headersOnly = headersOnly;
this.allowTrailers = allowTrailers;
this.keepAlive = keepAlive;
}

abstract void onWriteSuccess();
Expand Down Expand Up @@ -169,7 +174,7 @@ final boolean tryInitialize() {
}

this.session = session;
addResponseToDecoder();
responseWrapper = responseDecoder.addResponse(id, originalRes, ctx, ch.eventLoop());

if (timeoutMillis > 0) {
// The timer would be executed if the first message has not been sent out within the timeout.
Expand All @@ -180,13 +185,6 @@ final boolean tryInitialize() {
return true;
}

private void addResponseToDecoder() {
final long responseTimeoutMillis = ctx.responseTimeoutMillis();
final long maxContentLength = ctx.maxResponseLength();
responseWrapper = responseDecoder.addResponse(id, originalRes, ctx,
ch.eventLoop(), responseTimeoutMillis, maxContentLength);
}

/**
* Writes the {@link RequestHeaders} to the {@link Channel}.
* The {@link RequestHeaders} is merged with {@link ClientRequestContext#additionalRequestHeaders()}
Expand All @@ -199,8 +197,10 @@ final void writeHeaders(RequestHeaders headers) {
assert protocol != null;
if (headersOnly) {
state = State.DONE;
} else {
} else if (allowTrailers) {
state = State.NEEDS_DATA_OR_TRAILERS;
} else {
state = State.NEEDS_DATA;
}

final HttpHeaders internalHeaders;
Expand All @@ -215,7 +215,7 @@ final void writeHeaders(RequestHeaders headers) {
logBuilder.requestHeaders(merged);

final String connectionOption = headers.get(HttpHeaderNames.CONNECTION);
if (CLOSE_STRING.equalsIgnoreCase(connectionOption)) {
if (CLOSE_STRING.equalsIgnoreCase(connectionOption) || !keepAlive) {
// Make the session unhealthy so that subsequent requests do not use it.
// In HTTP/2 request, the "Connection: close" is just interpreted as a signal to close the
// connection by sending a GOAWAY frame that will be sent after receiving the corresponding
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright 2016 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.client;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpObject;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.client.DecodedHttpResponse;

import io.netty.channel.Channel;

abstract class AbstractHttpRequestSubscriber extends AbstractHttpRequestHandler
implements Subscriber<HttpObject> {

private static final HttpData EMPTY_EOS = HttpData.empty().withEndOfStream();

static AbstractHttpRequestSubscriber of(Channel channel, ClientHttpObjectEncoder requestEncoder,
HttpResponseDecoder responseDecoder, SessionProtocol protocol,
ClientRequestContext ctx, HttpRequest req,
DecodedHttpResponse res, long writeTimeoutMillis,
boolean webSocket) {
if (webSocket) {
if (protocol.isExplicitHttp1()) {
return new WebSocketHttp1RequestSubscriber(
channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis);
}
assert protocol.isExplicitHttp2();
return new WebSocketHttp2RequestSubscriber(
channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis);
}
return new HttpRequestSubscriber(
channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis);
}

private final HttpRequest request;

@Nullable
private Subscription subscription;
private boolean isSubscriptionCompleted;

AbstractHttpRequestSubscriber(Channel ch, ClientHttpObjectEncoder encoder,
HttpResponseDecoder responseDecoder,
HttpRequest request, DecodedHttpResponse originalRes,
ClientRequestContext ctx, long timeoutMillis, boolean allowTrailers,
boolean keepAlive) {
super(ch, encoder, responseDecoder, originalRes, ctx, timeoutMillis, request.isEmpty(), allowTrailers,
keepAlive);
this.request = request;
}

@Override
public void onSubscribe(Subscription subscription) {
assert this.subscription == null;
this.subscription = subscription;
if (state() == State.DONE) {
cancel();
return;
}

if (!tryInitialize()) {
return;
}

// NB: This must be invoked at the end of this method because otherwise the callback methods in this
// class can be called before the member fields (subscription, id, responseWrapper and
// timeoutFuture) are initialized.
// It is because the successful write of the first headers will trigger subscription.request(1).
writeHeaders(mapHeaders(request.headers()));
channel().flush();
}

RequestHeaders mapHeaders(RequestHeaders headers) {
return headers;
}

@Override
public void onError(Throwable cause) {
isSubscriptionCompleted = true;
failRequest(cause);
}

@Override
public void onComplete() {
isSubscriptionCompleted = true;

if (state() != State.DONE) {
writeData(EMPTY_EOS);
channel().flush();
}
}

@Override
void onWriteSuccess() {
// Request more messages regardless whether the state is DONE. It makes the producer have
// a chance to produce the last call such as 'onComplete' and 'onError' when there are
// no more messages it can produce.
if (!isSubscriptionCompleted) {
assert subscription != null;
subscription.request(1);
}
}

@Override
void cancel() {
isSubscriptionCompleted = true;
assert subscription != null;
subscription.cancel();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright 2016 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.client;

import java.util.Iterator;

import com.linecorp.armeria.common.ContentTooLargeException;
import com.linecorp.armeria.common.ContentTooLargeExceptionBuilder;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.client.DecodedHttpResponse;
import com.linecorp.armeria.internal.client.HttpSession;
import com.linecorp.armeria.internal.common.InboundTrafficController;
import com.linecorp.armeria.internal.common.KeepAliveHandler;

import io.netty.channel.Channel;
import io.netty.channel.EventLoop;
import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap;

abstract class AbstractHttpResponseDecoder implements HttpResponseDecoder {

private final IntObjectMap<HttpResponseWrapper> responses = new IntObjectHashMap<>();
private final Channel channel;
private final InboundTrafficController inboundTrafficController;

@Nullable
private HttpSession httpSession;

private int unfinishedResponses;
private boolean closing;

AbstractHttpResponseDecoder(Channel channel, InboundTrafficController inboundTrafficController) {
this.channel = channel;
this.inboundTrafficController = inboundTrafficController;
}

@Override
public Channel channel() {
return channel;
}

@Override
public InboundTrafficController inboundTrafficController() {
return inboundTrafficController;
}

@Override
public HttpResponseWrapper addResponse(
int id, DecodedHttpResponse res, ClientRequestContext ctx, EventLoop eventLoop) {
final HttpResponseWrapper newRes =
new HttpResponseWrapper(res, eventLoop, ctx,
ctx.responseTimeoutMillis(), ctx.maxResponseLength());
final HttpResponseWrapper oldRes = responses.put(id, newRes);
final KeepAliveHandler keepAliveHandler = keepAliveHandler();
if (keepAliveHandler != null) {
keepAliveHandler.increaseNumRequests();
}

assert oldRes == null : "addResponse(" + id + ", " + res + ", " + ctx + "): " + oldRes;
onResponseAdded(id, eventLoop, newRes);
return newRes;
}

abstract void onResponseAdded(int id, EventLoop eventLoop, HttpResponseWrapper responseWrapper);

@Nullable
@Override
public HttpResponseWrapper getResponse(int id) {
return responses.get(id);
}

@Nullable
@Override
public HttpResponseWrapper removeResponse(int id) {
if (closing) {
// `unfinishedResponses` will be removed by `failUnfinishedResponses()`
return null;
}

final HttpResponseWrapper removed = responses.remove(id);
if (removed != null) {
unfinishedResponses--;
assert unfinishedResponses >= 0 : unfinishedResponses;
}
return removed;
}

@Override
public boolean hasUnfinishedResponses() {
return unfinishedResponses != 0;
}

@Override
public boolean reserveUnfinishedResponse(int maxUnfinishedResponses) {
if (unfinishedResponses >= maxUnfinishedResponses) {
return false;
}

unfinishedResponses++;
return true;
}

@Override
public void decrementUnfinishedResponses() {
unfinishedResponses--;
}

@Override
public void failUnfinishedResponses(Throwable cause) {
if (closing) {
return;
}
closing = true;

for (final Iterator<HttpResponseWrapper> iterator = responses.values().iterator();
iterator.hasNext();) {
final HttpResponseWrapper res = iterator.next();
// To avoid calling removeResponse by res.close(cause), remove before closing.
iterator.remove();
unfinishedResponses--;
res.close(cause);
}
}

@Override
public HttpSession session() {
if (httpSession != null) {
return httpSession;
}
return httpSession = HttpSession.get(channel);
}

@Override
public boolean needsToDisconnectNow() {
return !session().isAcquirable() && !hasUnfinishedResponses();
}

static ContentTooLargeException contentTooLargeException(HttpResponseWrapper res, long transferred) {
final ContentTooLargeExceptionBuilder builder =
ContentTooLargeException.builder()
.maxContentLength(res.maxContentLength())
.transferred(transferred);
if (res.contentLengthHeaderValue() >= 0) {
builder.contentLength(res.contentLengthHeaderValue());
}
return builder.build();
}
}
Loading

0 comments on commit dc2bdc6

Please sign in to comment.