Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tls.Conn.HandshakeContext variant instead of non-context one #40588

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/client/contextdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeou
InsecureSkipVerify: insecure,
ServerName: host,
})
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, trace.Wrap(err)
}

Expand Down
4 changes: 2 additions & 2 deletions api/utils/pingconn/pingconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) {
tlsConnChan <- struct {
*tls.Conn
error
}{tlsConn, tlsConn.Handshake()}
}{tlsConn, tlsConn.HandshakeContext(ctx)}
}()

// Client
Expand All @@ -388,7 +388,7 @@ func makeTLSConn(t *testing.T, server, client net.Conn) (*tls.Conn, *tls.Conn) {
tlsConnChan <- struct {
*tls.Conn
error
}{tlsConn, tlsConn.Handshake()}
}{tlsConn, tlsConn.HandshakeContext(ctx)}
}()

tlsConnSlice := make([]*tls.Conn, 2)
Expand Down
2 changes: 1 addition & 1 deletion lib/multiplexer/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (l *TLSListener) detectAndForward(conn *tls.Conn) {
}

start := l.cfg.Clock.Now()
if err := conn.Handshake(); err != nil {
if err := conn.HandshakeContext(l.context); err != nil {
if !errors.Is(trace.Unwrap(err), io.EOF) {
l.log.WithFields(log.Fields{
"src_addr": conn.RemoteAddr(),
Expand Down
2 changes: 1 addition & 1 deletion lib/multiplexer/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (l *WebListener) detectAndForward(conn *tls.Conn) {
return
}

if err := conn.Handshake(); err != nil {
if err := conn.HandshakeContext(l.context); err != nil {
if !errors.Is(trace.Unwrap(err), io.EOF) {
l.log.WithFields(logrus.Fields{
"src_addr": conn.RemoteAddr(),
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/alpnproxy/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func newBufferedConn(conn net.Conn, header io.Reader) *bufferedConn {
// },
// })
//
// err := tlsConn.Handshake()
// err := tlsConn.HandshakeContext(ctx)
//
// if hello == nil {
// return trace.Wrap(err)
Expand Down
10 changes: 5 additions & 5 deletions lib/srv/alpnproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ type HandlerFuncWithInfo func(ctx context.Context, conn net.Conn, info Connectio
// was set if yes forward to the generic TLS DB handler.
// 6. Forward connection to the handler obtained in step 2.
func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOverride *tls.Config) error {
hello, conn, err := p.readHelloMessageWithoutTLSTermination(clientConn)
hello, conn, err := p.readHelloMessageWithoutTLSTermination(ctx, clientConn)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -391,7 +391,7 @@ func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOver
if err := tlsConn.SetReadDeadline(p.cfg.Clock.Now().Add(p.cfg.ReadDeadline)); err != nil {
return trace.Wrap(err)
}
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return trace.Wrap(err)
}
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
Expand Down Expand Up @@ -497,7 +497,7 @@ func (p *Proxy) getTLSConfig(desc *HandlerDecs, defaultOverride *tls.Config) *tl
// readHelloMessageWithoutTLSTermination allows reading a ClientHelloInfo message without termination of
// incoming TLS connection. After calling readHelloMessageWithoutTLSTermination function a returned
// net.Conn should be used for further operation.
func (p *Proxy) readHelloMessageWithoutTLSTermination(conn net.Conn) (*tls.ClientHelloInfo, net.Conn, error) {
func (p *Proxy) readHelloMessageWithoutTLSTermination(ctx context.Context, conn net.Conn) (*tls.ClientHelloInfo, net.Conn, error) {
buff := new(bytes.Buffer)
var hello *tls.ClientHelloInfo
tlsConn := tls.Server(readOnlyConn{reader: io.TeeReader(conn, buff)}, &tls.Config{
Expand All @@ -513,7 +513,7 @@ func (p *Proxy) readHelloMessageWithoutTLSTermination(conn net.Conn) (*tls.Clien
// Following TLS handshake fails on the server side with error: "no certificates configured" after server
// receives a TLS hello message from the client. If handshake was able to read hello message it indicates successful
// flow otherwise TLS handshake error is returned.
err := tlsConn.Handshake()
err := tlsConn.HandshakeContext(ctx)
if hello == nil {
return nil, nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -545,7 +545,7 @@ func (p *Proxy) databaseHandlerWithTLSTermination(ctx context.Context, conn net.
}
return trace.Wrap(err)
}
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return trace.Wrap(err)
}
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/alpnproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestProxyKubeHandler(t *testing.T) {
kubeCert,
},
})
err := tlsConn.Handshake()
err := tlsConn.HandshakeContext(ctx)
require.NoError(t, err)
_, err = fmt.Fprint(tlsConn, kubernetesHandlerResponse)
require.NoError(t, err)
Expand Down Expand Up @@ -398,7 +398,7 @@ func TestProxyMakeConnectionHandler(t *testing.T) {
})
defer clientTLSConn.Close()

require.NoError(t, clientTLSConn.Handshake())
require.NoError(t, clientTLSConn.HandshakeContext(context.Background()))
require.Equal(t, string(common.ProtocolHTTP), clientTLSConn.ConnectionState().NegotiatedProtocol)
require.NoError(t, req.Write(clientTLSConn))

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/app/connections_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ See https://goteleport.com/docs/access-controls/device-trust/device-management/#
// carry identity in the client certificate.
func (c *ConnectionsHandler) getConnectionInfo(ctx context.Context, conn net.Conn) (*tls.Conn, authz.IdentityGetter, types.Application, error) {
tlsConn := tls.Server(conn, c.tlsConfig)
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, nil, nil, trace.Wrap(err, "TLS handshake failed")
}

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func TestProxyProtocolPostgresStartup(t *testing.T) {
}
if needsTLSUpgrade {
tlsConn := tls.Client(conn, clientTLSCfg)
require.NoError(t, tlsConn.Handshake())
require.NoError(t, tlsConn.HandshakeContext(ctx))
conn = tlsConn
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/proxyserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func (s *ProxyServer) serveGenericTLS(listener net.Listener, tlsConfig *tls.Conf
go func() {
defer clientConn.Close()
tlsConn := tls.Server(clientConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(s.closeCtx); err != nil {
if !utils.IsOKNetworkError(err) {
s.log.WithError(err).Errorf("%s TLS handshake failed.", dbName)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ func (s *Server) HandleConnection(conn net.Conn) {
// Perform the handshake explicitly, normally it should be performed
// on the first read/write but when the connection is passed over
// reverse tunnel it doesn't happen for some reason.
err := tlsConn.Handshake()
err := tlsConn.HandshakeContext(s.closeContext)
if err != nil {
log.WithError(err).Error("Failed to perform TLS handshake.")
return
Expand Down
2 changes: 1 addition & 1 deletion lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2155,7 +2155,7 @@ func mustStartWindowsDesktopMock(t *testing.T, authClient *auth.Server) *windows
return
}
tlsConn := tls.Server(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(context.Background()); err != nil {
t.Errorf("Unexpected error %v", err)
return
}
Expand Down
Loading