From b916df42d83dc520850075965e746b1cddee9682 Mon Sep 17 00:00:00 2001 From: Peter Nied Date: Thu, 28 Sep 2023 23:25:31 +0000 Subject: [PATCH] Very basic header validator Signed-off-by: Peter Nied --- .../opensearch/http/netty4/Netty4Http2IT.java | 32 +++++ .../http/netty4/Netty4Authorizer.java | 56 +++++++++ .../http/netty4/Netty4HttpRequestHandler.java | 5 +- .../netty4/Netty4HttpServerTransport.java | 64 +++++----- .../http/AbstractHttpServerTransport.java | 109 +++--------------- 5 files changed, 130 insertions(+), 136 deletions(-) create mode 100644 modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java index eba2c5ce1e094..44a186393adae 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/opensearch/http/netty4/Netty4Http2IT.java @@ -16,18 +16,26 @@ import org.opensearch.test.OpenSearchIntegTestCase.Scope; import java.util.Collection; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.stream.IntStream; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; import io.netty.util.ReferenceCounted; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.hasSize; +import io.netty.handler.codec.http2.HttpConversionUtil; +import static io.netty.handler.codec.http.HttpHeaderNames.HOST; + @ClusterScope(scope = Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1) public class Netty4Http2IT extends OpenSearchNetty4IntegTestCase { @@ -56,6 +64,30 @@ public void testThatNettyHttpServerSupportsHttp2GetUpgrades() throws Exception { } } + + public void testThatNettyHttpServerHandlesAuthenticateCheck() throws Exception { + HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); + TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses(); + TransportAddress transportAddress = randomFrom(boundAddresses); + + final FullHttpRequest unauthorizedRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + unauthorizedRequest.headers().add("blockme", "Not Allowed"); + unauthorizedRequest.headers().add(HOST, "localhost"); + unauthorizedRequest.headers().add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + + + final List responses = new ArrayList<>(); + try (Netty4HttpClient nettyHttpClient = Netty4HttpClient.http2() ) { + try { + FullHttpResponse unauthorizedResponse = nettyHttpClient.send(transportAddress.address(), unauthorizedRequest); + responses.add(unauthorizedResponse); + assertThat(unauthorizedResponse.status().code(), equalTo(401)); + } finally { + responses.forEach(ReferenceCounted::release); + } + } + } + public void testThatNettyHttpServerSupportsHttp2PostUpgrades() throws Exception { final List> requests = List.of(Tuple.tuple("/_search", "{\"query\":{ \"match_all\":{}}}")); diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java new file mode 100644 index 0000000000000..07c0035f5142b --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4Authorizer.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.http.netty4; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelFutureListener; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.util.ReferenceCountUtil; + +@ChannelHandler.Sharable +public class Netty4Authorizer extends ChannelInboundHandlerAdapter { + + final static Logger log = LogManager.getLogger(Netty4Authorizer.class); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (!(msg instanceof HttpRequest)) { + ctx.fireChannelRead(msg); + } + + HttpRequest request = (HttpRequest) msg; + if (!isAuthenticated(request)) { + final FullHttpResponse response = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.UNAUTHORIZED); + ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + ReferenceCountUtil.release(msg); + } else { + // Lets the request pass to the next channel handler + ctx.fireChannelRead(msg); + } + } + + private boolean isAuthenticated(HttpRequest request) { + log.info("Checking if request is authenticated:\n" + request); + + final boolean shouldBlock = request.headers().contains("blockme"); + + return !shouldBlock; + } +} diff --git a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java index cc331f1b4fd2d..d140eb6f0d1eb 100644 --- a/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java +++ b/modules/transport-netty4/src/main/java/org/opensearch/http/netty4/Netty4HttpRequestHandler.java @@ -54,12 +54,9 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response"); - public static final AttributeKey CONTEXT_TO_RESTORE = AttributeKey.newInstance( - "opensearch-http-request-thread-context" - ); - protected static class HttpChannelHandler extends ChannelInitializer { private final Netty4HttpServerTransport transport; @@ -427,10 +422,10 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpMessage msg) throws E final ChannelPipeline pipeline = ctx.pipeline(); pipeline.addAfter(ctx.name(), "handler", getRequestHandler()); pipeline.replace(this, "header_verifier", transport.createHeaderVerifier()); - pipeline.addAfter("header_verifier", "decompress", transport.createDecompressor()); - pipeline.addAfter("decompress", "aggregator", aggregator); + pipeline.addAfter("header_verifier", "decoder_compress", new HttpContentDecompressor()); + pipeline.addAfter("decoder_compress", "aggregator", aggregator); if (handlingSettings.isCompression()) { - pipeline.addAfter("aggregator", "compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + pipeline.addAfter("aggregator", "encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } pipeline.addBefore("handler", "request_creator", requestCreator); pipeline.addBefore("handler", "response_creator", responseCreator); @@ -450,13 +445,13 @@ protected void configureDefaultHttpPipeline(ChannelPipeline pipeline) { decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR); pipeline.addLast("decoder", decoder); pipeline.addLast("header_verifier", transport.createHeaderVerifier()); - pipeline.addLast("decompress", transport.createDecompressor()); + pipeline.addLast("decoder_compress", new HttpContentDecompressor()); pipeline.addLast("encoder", new HttpResponseEncoder()); final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength()); aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents); pipeline.addLast("aggregator", aggregator); if (handlingSettings.isCompression()) { - pipeline.addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + pipeline.addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } pipeline.addLast("request_creator", requestCreator); pipeline.addLast("response_creator", responseCreator); @@ -497,10 +492,10 @@ protected void initChannel(Channel childChannel) throws Exception { .addLast("byte_buf_sizer", byteBufSizer) .addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)) .addLast("header_verifier", transport.createHeaderVerifier()) - .addLast("decompress", transport.createDecompressor()); + .addLast("decoder_decompress", new HttpContentDecompressor()); if (handlingSettings.isCompression()) { - childChannel.pipeline().addLast("compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); + childChannel.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } childChannel.pipeline() @@ -535,12 +530,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } } - protected HttpContentDecompressor createDecompressor() { - return new HttpContentDecompressor(); - } - protected ChannelInboundHandlerAdapter createHeaderVerifier() { + return new Netty4Authorizer(); // pass-through - return new ChannelInboundHandlerAdapter(); +// return new ChannelInboundHandlerAdapter(); } } diff --git a/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java index 61c2fa01c9c06..ed44102d0abe4 100644 --- a/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java @@ -53,7 +53,6 @@ import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestHandlerContext; import org.opensearch.rest.RestRequest; import org.opensearch.telemetry.tracing.Span; import org.opensearch.telemetry.tracing.SpanBuilder; @@ -100,7 +99,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo protected final ThreadPool threadPool; protected final Dispatcher dispatcher; protected final CorsHandler corsHandler; - protected final NamedXContentRegistry xContentRegistry; + private final NamedXContentRegistry xContentRegistry; protected final PortsRange port; protected final ByteSizeValue maxContentLength; @@ -113,7 +112,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo private final Set httpServerChannels = Collections.newSetFromMap(new ConcurrentHashMap<>()); private final HttpTracer httpTracer; - protected final Tracer tracer; + private final Tracer tracer; protected AbstractHttpServerTransport( Settings settings, @@ -360,29 +359,20 @@ protected void serverAcceptedChannel(HttpChannel httpChannel) { * * @param httpRequest that is incoming * @param httpChannel that received the http request - * @param requestContext context carried over to the request handler from earlier stages in the request pipeline */ - public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final RestHandlerContext requestContext) { + public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) { final Span span = tracer.startSpan(SpanBuilder.from(httpRequest), httpRequest.getHeaders()); try (final SpanScope httpRequestSpanScope = tracer.withSpanInScope(span)) { HttpChannel traceableHttpChannel = TraceableHttpChannel.create(httpChannel, span, tracer); - handleIncomingRequest(httpRequest, traceableHttpChannel, requestContext, httpRequest.getInboundException()); + handleIncomingRequest(httpRequest, traceableHttpChannel, httpRequest.getInboundException()); } } // Visible for testing - protected void dispatchRequest( - final RestRequest restRequest, - final RestChannel channel, - final Throwable badRequestCause, - final ThreadContext.StoredContext storedContext - ) { + void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) { RestChannel traceableRestChannel = channel; final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - if (storedContext != null) { - storedContext.restore(); - } final Span span = tracer.startSpan(SpanBuilder.from(restRequest)); try (final SpanScope spanScope = tracer.withSpanInScope(span)) { if (channel != null) { @@ -398,12 +388,7 @@ protected void dispatchRequest( } - private void handleIncomingRequest( - final HttpRequest httpRequest, - final HttpChannel httpChannel, - final RestHandlerContext requestContext, - final Exception exception - ) { + private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { if (exception == null) { HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest); if (earlyResponse != null) { @@ -426,13 +411,13 @@ private void handleIncomingRequest( { RestRequest innerRestRequest; try { - innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel, true); + innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel); } catch (final RestRequest.ContentTypeHeaderException e) { badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause, true); + innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause); } catch (final RestRequest.BadParameterException e) { badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel, true); + innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); } restRequest = innerRestRequest; } @@ -477,84 +462,16 @@ private void handleIncomingRequest( channel = innerChannel; } - if (requestContext.hasEarlyResponse()) { - channel.sendResponse(requestContext.getEarlyResponse()); - return; - } - - dispatchRequest(restRequest, channel, badRequestCause, requestContext.getContextToRestore()); - } - - public static RestRequest createRestRequest( - final NamedXContentRegistry xContentRegistry, - final HttpRequest httpRequest, - final HttpChannel httpChannel - ) { - // TODO Figure out how to only generate one request ID for each request in the pipeline. - Exception badRequestCause = httpRequest.getInboundException(); - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final RestRequest restRequest; - { - RestRequest innerRestRequest; - try { - innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel, false); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = requestWithoutContentTypeHeader(xContentRegistry, httpRequest, httpChannel, badRequestCause, false); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); - innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel, false); - } - restRequest = innerRestRequest; - } - return restRequest; - } - - private static RestRequest requestWithoutContentTypeHeader( - NamedXContentRegistry xContentRegistry, - HttpRequest httpRequest, - HttpChannel httpChannel, - Exception badRequestCause, - boolean shouldGenerateRequestId - ) { - HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type"); - try { - return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel, shouldGenerateRequestId); - } catch (final RestRequest.BadParameterException e) { - badRequestCause.addSuppressed(e); - return RestRequest.requestWithoutParameters( - xContentRegistry, - httpRequestWithoutContentType, - httpChannel, - shouldGenerateRequestId - ); - } + dispatchRequest(restRequest, channel, badRequestCause); } - private RestRequest requestWithoutContentTypeHeader( - HttpRequest httpRequest, - HttpChannel httpChannel, - Exception badRequestCause, - boolean shouldGenerateRequestId - ) { + private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) { HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type"); try { - return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel, shouldGenerateRequestId); + return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel); } catch (final RestRequest.BadParameterException e) { badRequestCause.addSuppressed(e); - return RestRequest.requestWithoutParameters( - xContentRegistry, - httpRequestWithoutContentType, - httpChannel, - shouldGenerateRequestId - ); + return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel); } }