Skip to content

Commit

Permalink
Merge pull request #447 from lutovich/1.5-handshake-timeout
Browse files Browse the repository at this point in the history
Use connect timeout in Bolt and TLS handshake
  • Loading branch information
ali-ince authored Dec 15, 2017
2 parents 7ad54ec + a56490a commit 3eb4c9d
Show file tree
Hide file tree
Showing 8 changed files with 395 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;

import java.util.Map;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.ConnectionSettings;
import org.neo4j.driver.internal.async.inbound.ConnectTimeoutHandler;
import org.neo4j.driver.internal.security.InternalAuthToken;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.Clock;
Expand Down Expand Up @@ -71,20 +73,47 @@ public ChannelConnectorImpl( ConnectionSettings connectionSettings, SecurityPlan
public ChannelFuture connect( BoltServerAddress address, Bootstrap bootstrap )
{
bootstrap.option( ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis );
bootstrap.handler( new NettyChannelInitializer( address, securityPlan, clock, logging ) );
bootstrap.handler( new NettyChannelInitializer( address, securityPlan, connectTimeoutMillis, clock, logging ) );

ChannelFuture channelConnected = bootstrap.connect( address.toSocketAddress() );

Channel channel = channelConnected.channel();
ChannelPromise handshakeCompleted = channel.newPromise();
ChannelPromise connectionInitialized = channel.newPromise();

installChannelConnectedListeners( address, channelConnected, handshakeCompleted );
installHandshakeCompletedListeners( handshakeCompleted, connectionInitialized );

return connectionInitialized;
}

private void installChannelConnectedListeners( BoltServerAddress address, ChannelFuture channelConnected,
ChannelPromise handshakeCompleted )
{
ChannelPipeline pipeline = channelConnected.channel().pipeline();

// add timeout handler to the pipeline when channel is connected. it's needed to limit amount of time code
// spends in TLS and Bolt handshakes. prevents infinite waiting when database does not respond
channelConnected.addListener( future ->
pipeline.addFirst( new ConnectTimeoutHandler( connectTimeoutMillis ) ) );

// add listener that sends Bolt handshake bytes when channel is connected
channelConnected.addListener(
new ChannelConnectedListener( address, pipelineBuilder, handshakeCompleted, logging ) );
handshakeCompleted.addListener(
new HandshakeCompletedListener( userAgent, authToken, connectionInitialized ) );
}

return connectionInitialized;
private void installHandshakeCompletedListeners( ChannelPromise handshakeCompleted,
ChannelPromise connectionInitialized )
{
ChannelPipeline pipeline = handshakeCompleted.channel().pipeline();

// remove timeout handler from the pipeline once TLS and Bolt handshakes are completed. regular protocol
// messages will flow next and we do not want to have read timeout for them
handshakeCompleted.addListener( future -> pipeline.remove( ConnectTimeoutHandler.class ) );

// add listener that sends an INIT message. connection is now fully established. channel pipeline if fully
// set to send/receive messages for a selected protocol version
handshakeCompleted.addListener( new HandshakeCompletedListener( userAgent, authToken, connectionInitialized ) );
}

private static Map<String,Value> tokenAsMap( AuthToken token )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ public class NettyChannelInitializer extends ChannelInitializer<Channel>
{
private final BoltServerAddress address;
private final SecurityPlan securityPlan;
private final int connectTimeoutMillis;
private final Clock clock;
private final Logging logging;

public NettyChannelInitializer( BoltServerAddress address, SecurityPlan securityPlan, Clock clock, Logging logging )
public NettyChannelInitializer( BoltServerAddress address, SecurityPlan securityPlan, int connectTimeoutMillis,
Clock clock, Logging logging )
{
this.address = address;
this.securityPlan = securityPlan;
this.connectTimeoutMillis = connectTimeoutMillis;
this.clock = clock;
this.logging = logging;
}
Expand All @@ -65,7 +68,9 @@ protected void initChannel( Channel channel )
private SslHandler createSslHandler()
{
SSLEngine sslEngine = createSslEngine();
return new SslHandler( sslEngine );
SslHandler sslHandler = new SslHandler( sslEngine );
sslHandler.setHandshakeTimeoutMillis( connectTimeoutMillis );
return sslHandler;
}

private SSLEngine createSslEngine()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal.async.inbound;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.timeout.ReadTimeoutHandler;

import java.util.concurrent.TimeUnit;

import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;

/**
* Handler needed to limit amount of time connection performs TLS and Bolt handshakes.
* It should only be used when connection is established and removed from the pipeline afterwards.
* Otherwise it will make long running queries fail.
*/
public class ConnectTimeoutHandler extends ReadTimeoutHandler
{
private final long timeoutMillis;
private boolean triggered;

public ConnectTimeoutHandler( long timeoutMillis )
{
super( timeoutMillis, TimeUnit.MILLISECONDS );
this.timeoutMillis = timeoutMillis;
}

@Override
protected void readTimedOut( ChannelHandlerContext ctx )
{
if ( !triggered )
{
triggered = true;
ctx.fireExceptionCaught( unableToConnectError() );
}
}

private ServiceUnavailableException unableToConnectError()
{
return new ServiceUnavailableException( "Unable to establish connection in " + timeoutMillis + "ms" );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,24 @@
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.ssl.SslHandler;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.RuleChain;
import org.junit.rules.Timeout;

import java.io.IOException;
import java.net.ConnectException;
import java.net.ServerSocket;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.ConnectionSettings;
import org.neo4j.driver.internal.async.inbound.ConnectTimeoutHandler;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.FakeClock;
import org.neo4j.driver.v1.AuthToken;
Expand All @@ -42,7 +49,9 @@

import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
Expand All @@ -52,19 +61,20 @@

public class ChannelConnectorImplTest
{
private final TestNeo4j neo4j = new TestNeo4j();
@Rule
public final TestNeo4j neo4j = new TestNeo4j();
public final RuleChain ruleChain = RuleChain.outerRule( Timeout.seconds( 20 ) ).around( neo4j );

private Bootstrap bootstrap;

@Before
public void setUp() throws Exception
public void setUp()
{
bootstrap = BootstrapFactory.newBootstrap( 1 );
}

@After
public void tearDown() throws Exception
public void tearDown()
{
if ( bootstrap != null )
{
Expand All @@ -75,7 +85,7 @@ public void tearDown() throws Exception
@Test
public void shouldConnect() throws Exception
{
ChannelConnectorImpl connector = newConnector( neo4j.authToken() );
ChannelConnector connector = newConnector( neo4j.authToken() );

ChannelFuture channelFuture = connector.connect( neo4j.address(), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );
Expand All @@ -85,10 +95,26 @@ public void shouldConnect() throws Exception
assertTrue( channel.isActive() );
}

@Test
public void shouldSetupHandlers() throws Exception
{
ChannelConnector connector = newConnector( neo4j.authToken(), SecurityPlan.forAllCertificates(), 10_000 );

ChannelFuture channelFuture = connector.connect( neo4j.address(), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );

Channel channel = channelFuture.channel();
ChannelPipeline pipeline = channel.pipeline();
assertTrue( channel.isActive() );

assertNotNull( pipeline.get( SslHandler.class ) );
assertNull( pipeline.get( ConnectTimeoutHandler.class ) );
}

@Test
public void shouldFailToConnectToWrongAddress() throws Exception
{
ChannelConnectorImpl connector = newConnector( neo4j.authToken() );
ChannelConnector connector = newConnector( neo4j.authToken() );

ChannelFuture channelFuture = connector.connect( new BoltServerAddress( "wrong-localhost" ), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );
Expand All @@ -112,7 +138,7 @@ public void shouldFailToConnectToWrongAddress() throws Exception
public void shouldFailToConnectWithWrongCredentials() throws Exception
{
AuthToken authToken = AuthTokens.basic( "neo4j", "wrong-password" );
ChannelConnectorImpl connector = newConnector( authToken );
ChannelConnector connector = newConnector( authToken );

ChannelFuture channelFuture = connector.connect( neo4j.address(), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );
Expand All @@ -131,10 +157,10 @@ public void shouldFailToConnectWithWrongCredentials() throws Exception
assertFalse( channel.isActive() );
}

@Test( timeout = 10000 )
@Test
public void shouldEnforceConnectTimeout() throws Exception
{
ChannelConnectorImpl connector = newConnector( neo4j.authToken(), 1000 );
ChannelConnector connector = newConnector( neo4j.authToken(), 1000 );

// try connect to a non-routable ip address 10.0.0.0, it will never respond
ChannelFuture channelFuture = connector.connect( new BoltServerAddress( "10.0.0.0" ), bootstrap );
Expand All @@ -151,15 +177,55 @@ public void shouldEnforceConnectTimeout() throws Exception
}
}

@Test
public void shouldFailWhenProtocolNegotiationTakesTooLong() throws Exception
{
// run without TLS so that Bolt handshake is the very first operation after connection is established
testReadTimeoutOnConnect( SecurityPlan.insecure() );
}

@Test
public void shouldFailWhenTLSHandshakeTakesTooLong() throws Exception
{
// run with TLS so that TLS handshake is the very first operation after connection is established
testReadTimeoutOnConnect( SecurityPlan.forAllCertificates() );
}

private void testReadTimeoutOnConnect( SecurityPlan securityPlan ) throws IOException
{
try ( ServerSocket server = new ServerSocket( 0 ) ) // server that accepts connections but does not reply
{
int timeoutMillis = 1_000;
BoltServerAddress address = new BoltServerAddress( "localhost", server.getLocalPort() );
ChannelConnector connector = newConnector( neo4j.authToken(), securityPlan, timeoutMillis );

ChannelFuture channelFuture = connector.connect( address, bootstrap );
try
{
await( channelFuture );
fail( "Exception expected" );
}
catch ( ServiceUnavailableException e )
{
assertEquals( e.getMessage(), "Unable to establish connection in " + timeoutMillis + "ms" );
}
}
}

private ChannelConnectorImpl newConnector( AuthToken authToken ) throws Exception
{
return newConnector( authToken, Integer.MAX_VALUE );
}

private ChannelConnectorImpl newConnector( AuthToken authToken, int connectTimeoutMillis ) throws Exception
{
ConnectionSettings settings = new ConnectionSettings( authToken, 1000 );
return new ChannelConnectorImpl( settings, SecurityPlan.forAllCertificates(), DEV_NULL_LOGGING,
new FakeClock() );
return newConnector( authToken, SecurityPlan.forAllCertificates(), connectTimeoutMillis );
}

private ChannelConnectorImpl newConnector( AuthToken authToken, SecurityPlan securityPlan,
int connectTimeoutMillis )
{
ConnectionSettings settings = new ConnectionSettings( authToken, connectTimeoutMillis );
return new ChannelConnectorImpl( settings, securityPlan, DEV_NULL_LOGGING, new FakeClock() );
}
}
Loading

0 comments on commit 3eb4c9d

Please sign in to comment.