From 5927ba1b3469a7e980cd0259ae1f70843f54fb87 Mon Sep 17 00:00:00 2001 From: Marco Dinis Date: Thu, 25 Apr 2024 10:39:05 +0100 Subject: [PATCH] Use HandshakeContext instead of Handshake for TLS Connections set up --- api/client/contextdialer.go | 2 +- api/utils/pingconn/pingconn_test.go | 4 ++-- lib/multiplexer/tls.go | 2 +- lib/multiplexer/web.go | 2 +- lib/srv/alpnproxy/conn.go | 2 +- lib/srv/alpnproxy/proxy.go | 10 +++++----- lib/srv/alpnproxy/proxy_test.go | 4 ++-- lib/srv/app/connections_handler.go | 2 +- lib/srv/db/proxy_test.go | 2 +- lib/srv/db/proxyserver.go | 2 +- lib/srv/db/server.go | 2 +- lib/web/apiserver_test.go | 2 +- 12 files changed, 18 insertions(+), 18 deletions(-) diff --git a/api/client/contextdialer.go b/api/client/contextdialer.go index e0fb6d9a6aa0e..f6f69625189b5 100644 --- a/api/client/contextdialer.go +++ b/api/client/contextdialer.go @@ -320,7 +320,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeou InsecureSkipVerify: insecure, ServerName: host, }) - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, trace.Wrap(err) } diff --git a/api/utils/pingconn/pingconn_test.go b/api/utils/pingconn/pingconn_test.go index 7514c7dd8cbe3..60d7eed84b7de 100644 --- a/api/utils/pingconn/pingconn_test.go +++ b/api/utils/pingconn/pingconn_test.go @@ -379,7 +379,7 @@ func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) { tlsConnChan <- struct { *tls.Conn error - }{tlsConn, tlsConn.Handshake()} + }{tlsConn, tlsConn.HandshakeContext(ctx)} }() // Client @@ -388,7 +388,7 @@ func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) { tlsConnChan <- struct { *tls.Conn error - }{tlsConn, tlsConn.Handshake()} + }{tlsConn, tlsConn.HandshakeContext(ctx)} }() tlsConnSlice := make([]*tls.Conn, 2) diff --git a/lib/multiplexer/tls.go b/lib/multiplexer/tls.go index c3d2046ee1007..8742acf0760a1 100644 --- a/lib/multiplexer/tls.go +++ b/lib/multiplexer/tls.go @@ -145,7 +145,7 @@ func (l *TLSListener) detectAndForward(conn *tls.Conn) { } start := l.cfg.Clock.Now() - if err := conn.Handshake(); err != nil { + if err := conn.HandshakeContext(l.context); err != nil { if !errors.Is(trace.Unwrap(err), io.EOF) { l.log.WithFields(log.Fields{ "src_addr": conn.RemoteAddr(), diff --git a/lib/multiplexer/web.go b/lib/multiplexer/web.go index 294ad762c6e80..e4e9a3d74c634 100644 --- a/lib/multiplexer/web.go +++ b/lib/multiplexer/web.go @@ -137,7 +137,7 @@ func (l *WebListener) detectAndForward(conn *tls.Conn) { return } - if err := conn.Handshake(); err != nil { + if err := conn.HandshakeContext(l.context); err != nil { if !errors.Is(trace.Unwrap(err), io.EOF) { l.log.WithFields(logrus.Fields{ "src_addr": conn.RemoteAddr(), diff --git a/lib/srv/alpnproxy/conn.go b/lib/srv/alpnproxy/conn.go index 712691c0a1b27..188037a54761b 100644 --- a/lib/srv/alpnproxy/conn.go +++ b/lib/srv/alpnproxy/conn.go @@ -66,7 +66,7 @@ func newBufferedConn(conn net.Conn, header io.Reader) *bufferedConn { // }, // }) // -// err := tlsConn.Handshake() +// err := tlsConn.HandshakeContext(ctx) // // if hello == nil { // return trace.Wrap(err) diff --git a/lib/srv/alpnproxy/proxy.go b/lib/srv/alpnproxy/proxy.go index 7325ec1378380..52b963891b658 100644 --- a/lib/srv/alpnproxy/proxy.go +++ b/lib/srv/alpnproxy/proxy.go @@ -367,7 +367,7 @@ type HandlerFuncWithInfo func(ctx context.Context, conn net.Conn, info Connectio // was set if yes forward to the generic TLS DB handler. // 6. Forward connection to the handler obtained in step 2. func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOverride *tls.Config) error { - hello, conn, err := p.readHelloMessageWithoutTLSTermination(clientConn) + hello, conn, err := p.readHelloMessageWithoutTLSTermination(ctx, clientConn) if err != nil { return trace.Wrap(err) } @@ -391,7 +391,7 @@ func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOver if err := tlsConn.SetReadDeadline(p.cfg.Clock.Now().Add(p.cfg.ReadDeadline)); err != nil { return trace.Wrap(err) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { return trace.Wrap(err) } if err := tlsConn.SetReadDeadline(time.Time{}); err != nil { @@ -497,7 +497,7 @@ func (p *Proxy) getTLSConfig(desc *HandlerDecs, defaultOverride *tls.Config) *tl // readHelloMessageWithoutTLSTermination allows reading a ClientHelloInfo message without termination of // incoming TLS connection. After calling readHelloMessageWithoutTLSTermination function a returned // net.Conn should be used for further operation. -func (p *Proxy) readHelloMessageWithoutTLSTermination(conn net.Conn) (*tls.ClientHelloInfo, net.Conn, error) { +func (p *Proxy) readHelloMessageWithoutTLSTermination(ctx context.Context, conn net.Conn) (*tls.ClientHelloInfo, net.Conn, error) { buff := new(bytes.Buffer) var hello *tls.ClientHelloInfo tlsConn := tls.Server(readOnlyConn{reader: io.TeeReader(conn, buff)}, &tls.Config{ @@ -513,7 +513,7 @@ func (p *Proxy) readHelloMessageWithoutTLSTermination(conn net.Conn) (*tls.Clien // Following TLS handshake fails on the server side with error: "no certificates configured" after server // receives a TLS hello message from the client. If handshake was able to read hello message it indicates successful // flow otherwise TLS handshake error is returned. - err := tlsConn.Handshake() + err := tlsConn.HandshakeContext(ctx) if hello == nil { return nil, nil, trace.Wrap(err) } @@ -545,7 +545,7 @@ func (p *Proxy) databaseHandlerWithTLSTermination(ctx context.Context, conn net. } return trace.Wrap(err) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { return trace.Wrap(err) } if err := tlsConn.SetReadDeadline(time.Time{}); err != nil { diff --git a/lib/srv/alpnproxy/proxy_test.go b/lib/srv/alpnproxy/proxy_test.go index 737b3b660dcbf..922c3dcc93ad7 100644 --- a/lib/srv/alpnproxy/proxy_test.go +++ b/lib/srv/alpnproxy/proxy_test.go @@ -90,7 +90,7 @@ func TestProxyKubeHandler(t *testing.T) { kubeCert, }, }) - err := tlsConn.Handshake() + err := tlsConn.HandshakeContext(ctx) require.NoError(t, err) _, err = fmt.Fprint(tlsConn, kubernetesHandlerResponse) require.NoError(t, err) @@ -398,7 +398,7 @@ func TestProxyMakeConnectionHandler(t *testing.T) { }) defer clientTLSConn.Close() - require.NoError(t, clientTLSConn.Handshake()) + require.NoError(t, clientTLSConn.HandshakeContext(context.Background())) require.Equal(t, string(common.ProtocolHTTP), clientTLSConn.ConnectionState().NegotiatedProtocol) require.NoError(t, req.Write(clientTLSConn)) diff --git a/lib/srv/app/connections_handler.go b/lib/srv/app/connections_handler.go index f070ce114c132..98d846aba7adb 100644 --- a/lib/srv/app/connections_handler.go +++ b/lib/srv/app/connections_handler.go @@ -728,7 +728,7 @@ See https://goteleport.com/docs/access-controls/device-trust/device-management/# // carry identity in the client certificate. func (c *ConnectionsHandler) getConnectionInfo(ctx context.Context, conn net.Conn) (*tls.Conn, authz.IdentityGetter, types.Application, error) { tlsConn := tls.Server(conn, c.tlsConfig) - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, nil, nil, trace.Wrap(err, "TLS handshake failed") } diff --git a/lib/srv/db/proxy_test.go b/lib/srv/db/proxy_test.go index 29cad898ed99b..afb9d06b8c543 100644 --- a/lib/srv/db/proxy_test.go +++ b/lib/srv/db/proxy_test.go @@ -252,7 +252,7 @@ func TestProxyProtocolPostgresStartup(t *testing.T) { } if needsTLSUpgrade { tlsConn := tls.Client(conn, clientTLSCfg) - require.NoError(t, tlsConn.Handshake()) + require.NoError(t, tlsConn.HandshakeContext(ctx)) conn = tlsConn } } diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index e94e290726674..0bd5a432845e6 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -276,7 +276,7 @@ func (s *ProxyServer) serveGenericTLS(listener net.Listener, tlsConfig *tls.Conf go func() { defer clientConn.Close() tlsConn := tls.Server(clientConn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(s.closeCtx); err != nil { if !utils.IsOKNetworkError(err) { s.log.WithError(err).Errorf("%s TLS handshake failed.", dbName) } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 6ab2cf5d44a06..45962e198500e 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -937,7 +937,7 @@ func (s *Server) HandleConnection(conn net.Conn) { // Perform the handshake explicitly, normally it should be performed // on the first read/write but when the connection is passed over // reverse tunnel it doesn't happen for some reason. - err := tlsConn.Handshake() + err := tlsConn.HandshakeContext(s.closeContext) if err != nil { log.WithError(err).Error("Failed to perform TLS handshake.") return diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index ebf5701d2d59c..51b5665c53557 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -2155,7 +2155,7 @@ func mustStartWindowsDesktopMock(t *testing.T, authClient *auth.Server) *windows return } tlsConn := tls.Server(conn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(context.Background()); err != nil { t.Errorf("Unexpected error %v", err) return }