Very basic header validator
Signed-off-by: Peter Nied <[email protected]>
peternied committed Sep 28, 2023
1 parent 104c512 commit b916df4
Showing 5 changed files with 130 additions and 136 deletions.
Netty4Http2IT.java
@@ -16,18 +16,26 @@
import org.opensearch.test.OpenSearchIntegTestCase.Scope;

import java.util.Collection;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.IntStream;

import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCounted;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasSize;

import io.netty.handler.codec.http2.HttpConversionUtil;
import static io.netty.handler.codec.http.HttpHeaderNames.HOST;

@ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1)
public class Netty4Http2IT extends OpenSearchNetty4IntegTestCase {

@@ -56,6 +64,30 @@ public void testThatNettyHttpServerSupportsHttp2GetUpgrades() throws Exception {
}
}

public void testThatNettyHttpServerHandlesAuthenticateCheck() throws Exception {
HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class);
TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses();
TransportAddress transportAddress = randomFrom(boundAddresses);

final FullHttpRequest unauthorizedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
unauthorizedRequest.headers().add("blockme", "Not Allowed");
unauthorizedRequest.headers().add(HOST, "localhost");
unauthorizedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http");

final List<FullHttpResponse> responses = new ArrayList<>();
try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2()) {
try {
FullHttpResponse unauthorizedResponse = nettyHttpClient.send(transportAddress.address(), unauthorizedRequest);
responses.add(unauthorizedResponse);
assertThat(unauthorizedResponse.status().code(), equalTo(401));
} finally {
responses.forEach(ReferenceCounted::release);
}
}
}

public void testThatNettyHttpServerSupportsHttp2PostUpgrades() throws Exception {
final List<Tuple<String, CharSequence>> requests = List.of(Tuple.tuple("/_search", "{\"query\":{ \"match_all\":{}}}"));

Netty4Authorizer.java
@@ -0,0 +1,56 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.http.netty4;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.ReferenceCountUtil;

@ChannelHandler.Sharable
public class Netty4Authorizer extends ChannelInboundHandlerAdapter {

private static final Logger log = LogManager.getLogger(Netty4Authorizer.class);

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (!(msg instanceof HttpRequest)) {
ctx.fireChannelRead(msg);
return;
}

HttpRequest request = (HttpRequest) msg;
if (!isAuthenticated(request)) {
final FullHttpResponse response = new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1,
HttpResponseStatus.UNAUTHORIZED);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
ReferenceCountUtil.release(msg);
} else {
// Lets the request pass to the next channel handler
ctx.fireChannelRead(msg);
}
}

private boolean isAuthenticated(HttpRequest request) {
log.info("Checking if request is authenticated:\n" + request);

final boolean shouldBlock = request.headers().contains("blockme");

return !shouldBlock;
}
}
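
Outside the diff itself, a quick way to see what the new handler does in isolation is Netty's EmbeddedChannel. The sketch below is illustrative only and not part of this commit; the class name Netty4AuthorizerSketch is invented, and it assumes the sketch lives in the same package as Netty4Authorizer.

package org.opensearch.http.netty4; // assumed, so the sketch can see Netty4Authorizer

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;

public class Netty4AuthorizerSketch {
    public static void main(String[] args) {
        // A request carrying the "blockme" header is answered with 401 and the channel is closed.
        EmbeddedChannel blocked = new EmbeddedChannel(new Netty4Authorizer());
        DefaultFullHttpRequest denied = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
        denied.headers().add("blockme", "Not Allowed");
        blocked.writeInbound(denied);
        FullHttpResponse response = blocked.readOutbound();
        System.out.println(response.status().equals(HttpResponseStatus.UNAUTHORIZED)); // expect true
        System.out.println(blocked.isOpen()); // expect false, the handler closes the channel
        response.release();

        // A request without the header is forwarded unchanged to the next inbound handler.
        EmbeddedChannel allowed = new EmbeddedChannel(new Netty4Authorizer());
        System.out.println(allowed.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"))); // expect true, request passed through
        allowed.finishAndReleaseAll();
    }
}
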
Netty4HttpRequestHandler.java
@@ -54,12 +54,9 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest> {
@Override
protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest httpRequest) {
final Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
final RestResponse earlyResponse = ctx.channel().attr(Netty4HttpServerTransport.EARLY_RESPONSE).get();
final ThreadContext.StoredContext contextToRestore = ctx.channel().attr(Netty4HttpServerTransport.CONTEXT_TO_RESTORE).get();
final RestHandlerContext requestContext = new RestHandlerContext(earlyResponse, contextToRestore);
boolean success = false;
try {
serverTransport.incomingRequest(httpRequest, channel, requestContext);
serverTransport.incomingRequest(httpRequest, channel);
success = true;
} finally {
if (success == false) {
Netty4HttpServerTransport.java
@@ -32,6 +32,25 @@

package org.opensearch.http.netty4;

import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CHUNK_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_ALIVE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_COUNT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_IDLE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_INTERVAL;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_NO_DELAY;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_RECEIVE_BUFFER_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_REUSE_ADDRESS;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_SEND_BUFFER_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS;

import java.net.InetSocketAddress;
import java.net.SocketOption;
import java.util.concurrent.TimeUnit;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
@@ -61,10 +80,6 @@
import org.opensearch.transport.SharedGroupFactory;
import org.opensearch.transport.netty4.Netty4Utils;

import java.net.InetSocketAddress;
import java.net.SocketOption;
import java.util.concurrent.TimeUnit;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
@@ -103,21 +118,6 @@
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;

import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CHUNK_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_ALIVE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_COUNT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_IDLE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_KEEP_INTERVAL;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_NO_DELAY;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_RECEIVE_BUFFER_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_REUSE_ADDRESS;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_TCP_SEND_BUFFER_SIZE;
import static org.opensearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS;

public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
private static final Logger logger = LogManager.getLogger(Netty4HttpServerTransport.class);

@@ -341,11 +341,6 @@ public ChannelHandler configureServerChannelHandler() {
"opensearch-http-server-channel"
);

public static final AttributeKey<RestResponse> EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response");
public static final AttributeKey<ThreadContext.StoredContext> CONTEXT_TO_RESTORE = AttributeKey.newInstance(
"opensearch-http-request-thread-context"
);

protected static class HttpChannelHandler extends ChannelInitializer<Channel> {

private final Netty4HttpServerTransport transport;
@@ -427,10 +422,10 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws Exception {
final ChannelPipeline pipeline = ctx.pipeline();
pipeline.addAfter(ctx.name(), "handler", getRequestHandler());
pipeline.replace(this, "header_verifier", transport.createHeaderVerifier());
pipeline.addAfter("header_verifier", "decompress", transport.createDecompressor());
pipeline.addAfter("decompress", "aggregator", aggregator);
pipeline.addAfter("header_verifier", "decoder_compress", new HttpContentDecompressor());
pipeline.addAfter("decoder_compress", "aggregator", aggregator);
if (handlingSettings.isCompression()) {
pipeline.addAfter("aggregator", "compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
pipeline.addAfter("aggregator", "encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
}
pipeline.addBefore("handler", "request_creator", requestCreator);
pipeline.addBefore("handler", "response_creator", responseCreator);
@@ -450,13 +445,13 @@ protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) {
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
pipeline.addLast("decoder", decoder);
pipeline.addLast("header_verifier", transport.createHeaderVerifier());
pipeline.addLast("decompress", transport.createDecompressor());
pipeline.addLast("decoder_compress", new HttpContentDecompressor());
pipeline.addLast("encoder", new HttpResponseEncoder());
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
pipeline.addLast("aggregator", aggregator);
if (handlingSettings.isCompression()) {
pipeline.addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
pipeline.addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
}
pipeline.addLast("request_creator", requestCreator);
pipeline.addLast("response_creator", responseCreator);
@@ -497,10 +492,10 @@ protected void initChannel(Channel childChannel) throws Exception {
.addLast("byte_buf_sizer", byteBufSizer)
.addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS))
.addLast("header_verifier", transport.createHeaderVerifier())
.addLast("decompress", transport.createDecompressor());
.addLast("decoder_decompress", new HttpContentDecompressor());

if (handlingSettings.isCompression()) {
childChannel.pipeline().addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
childChannel.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
}

childChannel.pipeline()
@@ -535,12 +530,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}

protected HttpContentDecompressor createDecompressor() {
return new HttpContentDecompressor();
}

protected ChannelInboundHandlerAdapter createHeaderVerifier() {
return new Netty4Authorizer();
// pass-through
return new ChannelInboundHandlerAdapter();
// return new ChannelInboundHandlerAdapter();
}
}
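
Because every pipeline variant above installs whatever createHeaderVerifier() returns under the "header_verifier" name, swapping in a different policy only takes another sharable ChannelInboundHandlerAdapter. The class below is a hypothetical sketch, not part of this commit: its name and its Authorization-header rule are invented for illustration, mirroring the shape of Netty4Authorizer.

package org.opensearch.http.netty4; // assumed location for the sketch

import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCountUtil;

@ChannelHandler.Sharable
class RequireAuthorizationHeaderVerifier extends ChannelInboundHandlerAdapter {
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof HttpRequest && !((HttpRequest) msg).headers().contains(HttpHeaderNames.AUTHORIZATION)) {
            // Reject before decompression, aggregation, or the dispatcher ever see the request.
            ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED))
                .addListener(ChannelFutureListener.CLOSE);
            ReferenceCountUtil.release(msg);
            return;
        }
        ctx.fireChannelRead(msg);
    }
}

A subclass of Netty4HttpServerTransport could then return this handler from createHeaderVerifier() in place of Netty4Authorizer.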