diff --git a/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java b/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java index fb0e584f66..dbaf0420ff 100644 --- a/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java +++ b/src/main/java/org/opensearch/security/dlic/rest/api/AbstractApiAction.java @@ -543,7 +543,6 @@ protected final RestChannelConsumer prepareRequest(RestRequest request, NodeClie return channel -> { final SecurityRequestChannel securityRequest = SecurityRequestFactory.from(request, channel); - // check if .opendistro_security index has been initialized if (!ensureIndexExists()) { internalSeverError(channel, RequestContentValidator.ValidationError.SECURITY_NOT_INITIALIZED.message()); diff --git a/src/main/java/org/opensearch/security/filter/NettyRequest.java b/src/main/java/org/opensearch/security/filter/NettyRequest.java new file mode 100644 index 0000000000..0d9cadb2ca --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyRequest.java @@ -0,0 +1,71 @@ +package org.opensearch.security.filter; + +import java.net.InetSocketAddress; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.net.ssl.SSLEngine; + +import org.opensearch.rest.RestRequest.Method; + +import io.netty.handler.codec.http.HttpRequest; + +/** + * Wraps the functionality of HttpRequest for use in the security plugin + */ +public class NettyRequest implements SecurityRequest { + protected final HttpRequest underlyingRequest; + + NettyRequest(final HttpRequest request) { + this.underlyingRequest = request; + } + + @Override + public Map> getHeaders() { + final Map> headers = new HashMap<>(); + underlyingRequest.headers().forEach(h -> headers.put(h.getKey(), List.of(h.getValue()))); + return headers; + } + + @Override + public SSLEngine getSSLEngine() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'getSSLEngine'"); + } + + @Override + public String path() { + try { + return new URL(underlyingRequest.uri()).getPath(); + } catch (final MalformedURLException e) { + return ""; + } + } + + @Override + public Method method() { + return Method.valueOf(underlyingRequest.method().name()); + } + + @Override + public Optional getRemoteAddress() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'getRemoteAddress'"); + } + + @Override + public String uri() { + return underlyingRequest.uri(); + } + + @Override + public Map params() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'params'"); + } + +} diff --git a/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java b/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java new file mode 100644 index 0000000000..ac78af1edf --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java @@ -0,0 +1,36 @@ +package org.opensearch.security.filter; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.commons.lang3.tuple.Triple; +import io.netty.handler.codec.http.HttpRequest; + +public class NettyRequestChannel extends NettyRequest implements SecurityRequestChannel { + + private final AtomicReference, String>> completedResult = new AtomicReference<>(); + NettyRequestChannel(final HttpRequest request) { + super(request); + } + + @Override + public boolean hasCompleted() { + return completedResult.get() != null; + } + + @Override + public boolean completeWithResponse(int statusCode, Map headers, String body) { + if (hasCompleted()) { + throw new UnsupportedOperationException("This channel has already completed"); + } + + completedResult.set(Triple.of(statusCode, headers, body)); + + return true; + } + + /** Accessor to get the completed response */ + public Triple, String> getCompletedRequest() { + return completedResult.get(); + } +} diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java index beb6103728..0c7e92f4b9 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java @@ -3,6 +3,8 @@ import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; +import io.netty.handler.codec.http.HttpRequest; + /** * Generates wrapped versions of requests for use in the security plugin */ @@ -17,4 +19,9 @@ public static SecurityRequest from(final RestRequest request) { public static SecurityRequestChannel from(final RestRequest request, final RestChannel channel) { return new OpenSearchRequestChannel(request, channel); } + + /** Creates a security request from a netty HttpRequest object */ + public static SecurityRequestChannel from(HttpRequest request) { + return new NettyRequestChannel(request); + } } diff --git a/src/main/java/org/opensearch/security/http/AuthenicationVerifier.java b/src/main/java/org/opensearch/security/http/AuthenicationVerifier.java index 9acbfb5141..8d261ce94e 100644 --- a/src/main/java/org/opensearch/security/http/AuthenicationVerifier.java +++ b/src/main/java/org/opensearch/security/http/AuthenicationVerifier.java @@ -1,21 +1,35 @@ package org.opensearch.security.http; +import java.util.Optional; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.security.filter.NettyRequestChannel; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; +import org.opensearch.security.filter.SecurityRestFilter; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; import io.netty.util.ReferenceCountUtil; -public class AuthenicationVerifier extends ChannelInboundHandlerAdapter { +public class AuthenticationVerifer extends ChannelInboundHandlerAdapter { + + final static Logger log = LogManager.getLogger(AuthenticationVerifer.class); - final static Logger log = LogManager.getLogger(AuthenicationVerifier.class); + private SecurityRestFilter restFilter; + + public AuthenticationVerifer(SecurityRestFilter restFilter) { + this.restFilter = restFilter; + } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { @@ -23,24 +37,33 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception ctx.fireChannelRead(msg); } - HttpRequest request = (HttpRequest) msg; - if (!isAuthenticated(request)) { - final FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); - ReferenceCountUtil.release(msg); + final HttpRequest request = (HttpRequest) msg; + final Optional shouldResponse = getAuthenticationResponse(request); + if (shouldResponse.isPresent()) { + ctx.writeAndFlush(shouldResponse.get()).addListener(ChannelFutureListener.CLOSE); } else { - // Lets the request pass to the next channel handler + // Let the request pass to the next channel handler ctx.fireChannelRead(msg); } } - private boolean isAuthenticated(HttpRequest request) { + private Optional getAuthenticationResponse(HttpRequest request) { log.info("Checking if request is authenticated:\n" + request); - final boolean shouldBlock = request.headers().contains("blockme"); + final NettyRequestChannel requestChannel = (NettyRequestChannel) SecurityRequestFactory.from(request); + restFilter.checkAndAuthenticateRequest(requestChannel); - return !shouldBlock; + if (requestChannel.hasCompleted()) { + final FullHttpResponse response = new DefaultFullHttpResponse( + request.protocolVersion(), + HttpResponseStatus.valueOf(requestChannel.getCompletedRequest().getLeft()), + Unpooled.copiedBuffer(requestChannel.getCompletedRequest().getRight().getBytes())); + requestChannel.getCompletedRequest().getMiddle().forEach((key, value) -> response.headers().set(key, value)); + return Optional.of(response); + } + + return Optional.empty(); } } diff --git a/src/main/java/org/opensearch/security/http/AuthenticationVerifer.java b/src/main/java/org/opensearch/security/http/AuthenticationVerifer.java new file mode 100644 index 0000000000..e56916a96e --- /dev/null +++ b/src/main/java/org/opensearch/security/http/AuthenticationVerifer.java @@ -0,0 +1,74 @@ +package org.opensearch.security.http; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.security.filter.NettyRequestChannel; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestFactory; +import org.opensearch.security.filter.SecurityRestFilter; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.ReferenceCountUtil; + +public class AuthenticationVerifer extends ChannelInboundHandlerAdapter { + + final static Logger log = LogManager.getLogger(AuthenticationVerifer.class); + + private SecurityRestFilter restFilter; + + public AuthenticationVerifer(SecurityRestFilter restFilter) { + this.restFilter = restFilter; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (!(msg instanceof HttpRequest)) { + ctx.fireChannelRead(msg); + } + + final HttpRequest request = (HttpRequest) msg; + final Optional shouldResponse = getAuthenticationResponse(request); + if (shouldResponse.isPresent()) { + ctx.writeAndFlush(shouldResponse.get()).addListener(ChannelFutureListener.CLOSE); + } else { + // Let the request pass to the next channel handler + ctx.fireChannelRead(msg); + } + } + + private Optional getAuthenticationResponse(HttpRequest request) { + + log.info("Checking if request is authenticated:\n" + request); + + final NettyRequestChannel requestChannel = (NettyRequestChannel) SecurityRequestFactory.from(request); + + try { + restFilter.checkAndAuthenticateRequest(requestChannel); + } catch (Exception e) { + log.error(e); + } + + if (requestChannel.hasCompleted()) { + final FullHttpResponse response = new DefaultFullHttpResponse( + request.protocolVersion(), + HttpResponseStatus.valueOf(requestChannel.getCompletedRequest().getLeft()), + Unpooled.copiedBuffer(requestChannel.getCompletedRequest().getRight().getBytes())); + requestChannel.getCompletedRequest().getMiddle().forEach((key, value) -> response.headers().set(key, value)); + return Optional.of(response); + } + + return Optional.empty(); + } + +}