From 0f6d09d600a4430825a30480a01896379095f130 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Tue, 22 Oct 2024 20:35:47 +0100 Subject: [PATCH] [kube] add server_id to targets when monitoring exec/portforward connections This PR adds the target server_id (kubernetes service) when proxy establishes a connection to support kubectl exec and portforward. This allows proxies to terminate early the connection without relying on the upstream to terminate it. --- lib/kube/proxy/forwarder.go | 22 +++++++++++++++++----- lib/kube/proxy/roundtrip.go | 1 - lib/kube/proxy/transport.go | 23 ++++++++++++++++++++--- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 98160c470d98d..ba0beb4369792 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -2176,7 +2176,7 @@ func (s *clusterSession) close() { } } -func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error) { +func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (net.Conn, error) { if err != nil { return nil, trace.Wrap(err) } @@ -2191,10 +2191,18 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error) s.connMonitorCancel(err) return nil, trace.Wrap(err) } - + lockTargets := s.LockTargets() + // when the target is not a kubernetes_service instance, we don't need to lock it. + // the target could be a remote cluster or a local Kubernetes API server. In both cases, + // hostID is empty. + if hostID != "" { + lockTargets = append(lockTargets, types.LockTarget{ + ServerID: hostID, + }) + } err = srv.StartMonitor(srv.MonitorConfig{ LockWatcher: s.parent.cfg.LockWatcher, - LockTargets: s.LockTargets(), + LockTargets: lockTargets, DisconnectExpiredCert: s.disconnectExpiredCert, ClientIdleTimeout: s.clientIdleTimeout, Clock: s.parent.cfg.Clock, @@ -2226,12 +2234,16 @@ func (s *clusterSession) getServerMetadata() apievents.ServerMetadata { } func (s *clusterSession) Dial(network, addr string) (net.Conn, error) { - return s.monitorConn(s.dial(s.requestContext, network, addr)) + var hostID string + conn, err := s.dial(s.requestContext, network, addr, withHostIDCollection(&hostID)) + return s.monitorConn(conn, err, hostID) } func (s *clusterSession) DialWithContext(opts ...contextDialerOption) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { - return s.monitorConn(s.dial(ctx, network, addr, opts...)) + var hostID string + conn, err := s.dial(ctx, network, addr, append(opts, withHostIDCollection(&hostID))...) + return s.monitorConn(conn, err, hostID) } } diff --git a/lib/kube/proxy/roundtrip.go b/lib/kube/proxy/roundtrip.go index f5507170ffecc..9049dabb61488 100644 --- a/lib/kube/proxy/roundtrip.go +++ b/lib/kube/proxy/roundtrip.go @@ -110,7 +110,6 @@ func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) { if err != nil { return nil, err } - if err := req.Write(conn); err != nil { conn.Close() return nil, err diff --git a/lib/kube/proxy/transport.go b/lib/kube/proxy/transport.go index ec5115501a2d2..e426a30901c12 100644 --- a/lib/kube/proxy/transport.go +++ b/lib/kube/proxy/transport.go @@ -493,6 +493,7 @@ func (f *Forwarder) localClusterDialer(kubeClusterName string, opts ...contextDi ProxyIDs: s.GetProxyIDs(), }) if err == nil { + opt.collect(s.GetHostID()) return conn, nil } errs = append(errs, trace.Wrap(err)) @@ -584,13 +585,21 @@ func (f *Forwarder) getContextDialerFunc(s *clusterSession, opts ...contextDiale // contextDialerOptions is a set of options that can be used to filter // the hosts that the dialer connects to. type contextDialerOptions struct { - hostID string + hostIDFilter string + collectHostID *string } // matches returns true if the host matches the hostID of the dialer options or // if the dialer hostID is empty. func (c *contextDialerOptions) matches(hostID string) bool { - return c.hostID == "" || c.hostID == hostID + return c.hostIDFilter == "" || c.hostIDFilter == hostID +} + +// collect sets the hostID that the dialer connected to if collectHostID is not nil. +func (c *contextDialerOptions) collect(hostID string) { + if c.collectHostID != nil { + *c.collectHostID = hostID + } } // contextDialerOption is a functional option for the contextDialerOptions. @@ -603,6 +612,14 @@ type contextDialerOption func(*contextDialerOptions) // error. func withTargetHostID(hostID string) contextDialerOption { return func(o *contextDialerOptions) { - o.hostID = hostID + o.hostIDFilter = hostID + } +} + +// withHostIDCollection is a functional option that sets the hostID of the dialer +// to the provided pointer. +func withHostIDCollection(hostID *string) contextDialerOption { + return func(o *contextDialerOptions) { + o.collectHostID = hostID } }