diff --git a/influent-java-sample/src/main/java/sample/TLSPrint.java b/influent-java-sample/src/main/java/sample/TLSPrint.java index f5da3a2..7a09a3a 100644 --- a/influent-java-sample/src/main/java/sample/TLSPrint.java +++ b/influent-java-sample/src/main/java/sample/TLSPrint.java @@ -38,7 +38,7 @@ public static void main(final String[] args) throws Exception { final ForwardServer server = new ForwardServer.Builder(callback) .sslEnabled(true) - .tlsVersions(new String[]{"TLSv1.2"}) + .tlsVersions("TLSv1.2") .keystorePath("./out/influent-server.jks") .keystorePassword("password") .keyPassword("password") diff --git a/influent-java/src/main/java/influent/forward/ForwardServer.java b/influent-java/src/main/java/influent/forward/ForwardServer.java index e663217..730a9dc 100644 --- a/influent-java/src/main/java/influent/forward/ForwardServer.java +++ b/influent-java/src/main/java/influent/forward/ForwardServer.java @@ -203,7 +203,7 @@ public Builder sslEnabled(final boolean value) { * @param value the TLS versions. Available elements are "TLS", "TLSv1", "TLSv1.1" or "TLSv1.2" * @return this builder */ - public Builder tlsVersions(final String[] value) { + public Builder tlsVersions(final String... value) { tlsVersions = value; return this; } @@ -276,8 +276,8 @@ public ForwardServer build() { receiveBufferSize, keepAliveEnabled, tcpNoDelayEnabled, - workerPoolSize == 0 ? DEFAULT_WORKER_POOL_SIZE : workerPoolSize - // TODO Add channelConfig here + workerPoolSize == 0 ? DEFAULT_WORKER_POOL_SIZE : workerPoolSize, + channelConfig ); } } diff --git a/influent-java/src/main/java/influent/forward/NioForwardServer.java b/influent-java/src/main/java/influent/forward/NioForwardServer.java index 521180c..6676d71 100644 --- a/influent-java/src/main/java/influent/forward/NioForwardServer.java +++ b/influent-java/src/main/java/influent/forward/NioForwardServer.java @@ -22,6 +22,7 @@ import java.util.concurrent.ThreadFactory; import java.util.function.Consumer; +import influent.internal.nio.NioChannelConfig; import influent.internal.nio.NioEventLoop; import influent.internal.nio.NioEventLoopPool; import influent.internal.nio.NioTcpAcceptor; @@ -57,13 +58,22 @@ final class NioForwardServer implements ForwardServer { final int receiveBufferSize, final boolean keepAliveEnabled, final boolean tcpNoDelayEnabled, - final int workerPoolSize) { + final int workerPoolSize, + final NioChannelConfig channelConfig) { bossEventLoop = NioEventLoop.open(); workerEventLoopPool = NioEventLoopPool.open(workerPoolSize); - final Consumer channelFactory = socketChannel -> new NioForwardConnection( - socketChannel, workerEventLoopPool.next(), callback, chunkSizeLimit, sendBufferSize, - keepAliveEnabled, tcpNoDelayEnabled - ); + final Consumer channelFactory; + if (channelConfig.isSslEnabled()) { + channelFactory = socketChannel -> new NioSslForwardConnection( + socketChannel, workerEventLoopPool.next(), callback, channelConfig.createSSLEngine(), + chunkSizeLimit, sendBufferSize, keepAliveEnabled, tcpNoDelayEnabled + ); + } else { + channelFactory = socketChannel -> new NioForwardConnection( + socketChannel, workerEventLoopPool.next(), callback, + chunkSizeLimit, sendBufferSize, keepAliveEnabled, tcpNoDelayEnabled + ); + } new NioTcpAcceptor( localAddress, bossEventLoop, channelFactory, backlog, receiveBufferSize ); diff --git a/influent-java/src/main/java/influent/forward/NioSslForwardConnection.java b/influent-java/src/main/java/influent/forward/NioSslForwardConnection.java new file mode 100644 index 0000000..3fac24f --- /dev/null +++ b/influent-java/src/main/java/influent/forward/NioSslForwardConnection.java @@ -0,0 +1,351 @@ +/* + * Copyright 2016 okumin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package influent.forward; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ReadOnlyBufferException; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.LinkedList; +import java.util.Queue; +import java.util.function.Supplier; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; + +import org.msgpack.core.MessageBufferPacker; +import org.msgpack.core.MessagePack; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import influent.exception.InfluentIOException; +import influent.internal.msgpack.MsgpackStreamUnpacker; +import influent.internal.nio.NioAttachment; +import influent.internal.nio.NioEventLoop; +import influent.internal.nio.NioTcpChannel; +import influent.internal.util.ThreadSafeQueue; + +/** + * A connection for SSL/TLS forward protocol. + */ +final class NioSslForwardConnection implements NioAttachment { + private static final Logger logger = LoggerFactory.getLogger(NioSslForwardConnection.class); + private static final String ACK_KEY = "ack"; + + private final NioTcpChannel channel; + private final NioEventLoop eventLoop; + private final ForwardCallback callback; + private final SSLEngine engine; + private final MsgpackStreamUnpacker unpacker; + private final MsgpackForwardRequestDecoder decoder; + + private final ThreadSafeQueue responses = new ThreadSafeQueue<>(); + + // Prepare a ByteBuffer with sufficient size + private ByteBuffer inboundNetworkBuffer = ByteBuffer.allocate(1024 * 1024); + private final Queue outboundNetworkBuffers = new LinkedList<>(); + + NioSslForwardConnection(final NioTcpChannel channel, + final NioEventLoop eventLoop, + final ForwardCallback callback, + final SSLEngine engine, + final MsgpackStreamUnpacker unpacker, + final MsgpackForwardRequestDecoder decoder) { + this.channel = channel; + this.eventLoop = eventLoop; + this.callback = callback; + this.engine = engine; + this.unpacker = unpacker; + this.decoder = decoder; + inboundNetworkBuffer.position(inboundNetworkBuffer.limit()); + } + + NioSslForwardConnection(final NioTcpChannel channel, + final NioEventLoop eventLoop, + final ForwardCallback callback, + final SSLEngine engine, + final long chunkSizeLimit) { + this( + channel, + eventLoop, + callback, + engine, + new MsgpackStreamUnpacker(chunkSizeLimit), + new MsgpackForwardRequestDecoder() + ); + } + + /** + * Constructs a new {@code NioSslForwardConnection}. + * + * @param socketChannel the inbound channel + * @param eventLoop the {@code NioEventLoop} to which this {@code NioSslForwardConnection} belongs + * @param callback the callback to handle requests + * @param chunkSizeLimit the allowable size of a chunk + * @param sendBufferSize enqueue buffer size + * the default value is used when the given {@code value} is empty + * @param keepAliveEnabled whether SO_KEEPALIVE is enabled or not + * @param tcpNoDelayEnabled whether TCP_NODELAY is enabled or not + * @throws InfluentIOException if some IO error occurs + */ + NioSslForwardConnection(final SocketChannel socketChannel, + final NioEventLoop eventLoop, + final ForwardCallback callback, + final SSLEngine engine, + final long chunkSizeLimit, + final int sendBufferSize, + final boolean keepAliveEnabled, + final boolean tcpNoDelayEnabled) { + this( + new NioTcpChannel(socketChannel, sendBufferSize, keepAliveEnabled, tcpNoDelayEnabled), + eventLoop, + callback, + engine, + chunkSizeLimit + ); + + channel.register(eventLoop, SelectionKey.OP_READ, this); + } + + /** + * Handles a write event. + * + * @param key the {@code SelectionKey} + * @throws InfluentIOException if some IO error occurs + */ + @Override + public void onWritable(final SelectionKey key) { + if (!handshake(key)) { + if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP) { + eventLoop.enableInterestSet(key, SelectionKey.OP_WRITE); + } + return; + } + + while (responses.nonEmpty()) { + final ByteBuffer head = responses.dequeue(); + wrapAndSend(key, head); + } + if (!channel.isOpen()) { + close(); + } + } + + /** + * Handles a read event. + * + * @param key the {@code SelectionKey} + * @throws InfluentIOException if some IO error occurs + */ + @Override + public void onReadable(final SelectionKey key) { + if (!handshake(key)) { + if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP) { + eventLoop.enableInterestSet(key, SelectionKey.OP_WRITE); + } + return; + } + + receiveRequests(key); + if (!channel.isOpen()) { + close(); + } + } + + private void receiveRequests(final SelectionKey key) { + // TODO: optimize + final Supplier supplier = () -> { + final ByteBuffer buffer = ByteBuffer.allocate(1024 * 1024); + receiveAndUnwrap(buffer); + buffer.flip(); + if (!buffer.hasRemaining()) { + return null; + } + return buffer; + }; + unpacker.feed(supplier, channel); + while (unpacker.hasNext()) { + try { + decoder.decode(unpacker.next()).ifPresent(result -> { + logger.debug( + "Received a forward request from {}. chunk_id = {}", + channel.getRemoteAddress(), result.getOption() + ); + callback.consume(result.getStream()).thenRun(() -> { + // Executes on user's callback thread since the queue never block. + result.getOption().getChunk().ifPresent(chunk -> completeTask(key, chunk)); + logger.debug("Completed the task. chunk_id = {}.", result.getOption()); + }); + }); + } catch (final IllegalArgumentException e) { + logger.error( + "Received an invalid message. remote address = " + channel.getRemoteAddress(), e + ); + } + } + } + + // This method is thread-safe. + private void completeTask(final SelectionKey key, final String chunk) { + try { + final MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); + packer.packMapHeader(1); + packer.packString(ACK_KEY); + packer.packString(chunk); + final ByteBuffer buffer = packer.toMessageBuffer().sliceAsByteBuffer(); + responses.enqueue(buffer); + eventLoop.enableInterestSet(key, SelectionKey.OP_WRITE); + } catch (final IOException e) { + logger.error("Failed packing. chunk = " + chunk, e); + } + } + + // true when the handshake is completed + private boolean handshake(final SelectionKey key) { + final SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus(); + logger.debug("Current handshake status: " + handshakeStatus); + if (!isHandshaking(handshakeStatus)) { + return true; + } + + switch (handshakeStatus) { + case NEED_UNWRAP: + return receiveAndUnwrap(ByteBuffer.allocate(1024 * 1024)) && handshake(key); + case NEED_WRAP: + return wrapAndSend(key, ByteBuffer.allocate(0)) && handshake(key); + case NEED_TASK: + while (true) { + final Runnable task = engine.getDelegatedTask(); + if (task == null) { + break; + } + task.run(); + } + return handshake(key); + case FINISHED: + case NOT_HANDSHAKING: + default: + throw new AssertionError(); + } + } + + private boolean wrapAndSend(final SelectionKey key, final ByteBuffer src) { + try { + final ByteBuffer buffer = ByteBuffer.allocate(1024 * 1024); + final SSLEngineResult result = engine.wrap(src, buffer); + switch (result.getStatus()) { + case OK: + break; + case CLOSED: + close(); + break; + case BUFFER_OVERFLOW: + case BUFFER_UNDERFLOW: + default: + throw new AssertionError(); + } + + buffer.flip(); + if (buffer.hasRemaining()) { + outboundNetworkBuffers.add(buffer); + } + while (!outboundNetworkBuffers.isEmpty()) { + final ByteBuffer head = outboundNetworkBuffers.peek(); + final int bytes = channel.write(head); + if (bytes == 0) { + break; + } + if (!head.hasRemaining()) { + outboundNetworkBuffers.poll(); + } + } + if (outboundNetworkBuffers.isEmpty() && key.isWritable()) { + eventLoop.disableInterestSet(key, SelectionKey.OP_WRITE); + } + return outboundNetworkBuffers.isEmpty(); + } catch (final SSLException e) { + throw new InfluentIOException("Illegal SSL/TLS processing was detected.", e); + } catch (final ReadOnlyBufferException | IllegalArgumentException | IllegalStateException e) { + throw new AssertionError(e); + } + } + + private boolean receiveAndUnwrap(final ByteBuffer dst) { + try { + if (!inboundNetworkBuffer.hasRemaining()) { + inboundNetworkBuffer.clear(); + inboundNetworkBuffer.mark(); + } else { + inboundNetworkBuffer.mark(); + inboundNetworkBuffer.position(inboundNetworkBuffer.limit()); + inboundNetworkBuffer.limit(inboundNetworkBuffer.capacity()); + } + final int bytes = channel.read(inboundNetworkBuffer); + inboundNetworkBuffer.limit(inboundNetworkBuffer.position()); + inboundNetworkBuffer.reset(); + if (!inboundNetworkBuffer.hasRemaining()) { + return false; + } + while (inboundNetworkBuffer.hasRemaining()) { + final int start = dst.position(); + final SSLEngineResult result = engine.unwrap(inboundNetworkBuffer, dst); + switch (result.getStatus()) { + case OK: + if (dst.position() == start) { + return true; + } + break; + case BUFFER_UNDERFLOW: + return bytes != 0; + case CLOSED: + close(); + if (dst.position() == start) { + return false; + } + break; + case BUFFER_OVERFLOW: + default: + throw new AssertionError(); + } + } + return true; + } catch (final SSLException e) { + throw new InfluentIOException("Illegal SSL/TLS processing was detected.", e); + } catch (final ReadOnlyBufferException | IllegalArgumentException | IllegalStateException e) { + throw new AssertionError(e); + } + } + + private static boolean isHandshaking(final SSLEngineResult.HandshakeStatus status) { + return status != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING + && status != SSLEngineResult.HandshakeStatus.FINISHED; + } + + @Override + public void close() { + // TODO: graceful stop + channel.close(); + logger.debug("NioSslForwardConnection bound with {} closed.", channel.getRemoteAddress()); + } + + @Override + public String toString() { + return "NioSslForwardConnection(" + channel.getRemoteAddress() + ")"; + } +} diff --git a/influent-java/src/test/scala/influent/internal/msgpack/MsgpackStreamUnpackerSpec.scala b/influent-java/src/test/scala/influent/internal/msgpack/MsgpackStreamUnpackerSpec.scala index a6a0463..8db216c 100644 --- a/influent-java/src/test/scala/influent/internal/msgpack/MsgpackStreamUnpackerSpec.scala +++ b/influent-java/src/test/scala/influent/internal/msgpack/MsgpackStreamUnpackerSpec.scala @@ -143,7 +143,7 @@ class MsgpackStreamUnpackerSpec val supplier = new Supplier[ByteBuffer] { override def get(): ByteBuffer = { if (groupedBytes.hasNext) { - val buffer = ByteBuffer.allocate(1024 * 1024) + val buffer = ByteBuffer.allocate(1024 * 16) buffer.put(groupedBytes.next()).flip() buffer } else { diff --git a/influent-transport/src/main/java/influent/internal/nio/NioChannelConfig.java b/influent-transport/src/main/java/influent/internal/nio/NioChannelConfig.java index a6c7762..87929e6 100644 --- a/influent-transport/src/main/java/influent/internal/nio/NioChannelConfig.java +++ b/influent-transport/src/main/java/influent/internal/nio/NioChannelConfig.java @@ -31,6 +31,9 @@ import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; + +import influent.exception.InfluentIOException; public class NioChannelConfig { @@ -79,6 +82,11 @@ public SSLEngine createSSLEngine() { if (ciphers != null) { engine.setEnabledCipherSuites(ciphers); } + try { + engine.beginHandshake(); + } catch (final SSLException e) { + throw new InfluentIOException("Failed beginning a handshake.", e); + } // TODO configure engine return engine; }