diff --git a/lib/limiter/limiter.go b/lib/limiter/limiter.go index 0134d9487fe00..9528caac2bdb5 100644 --- a/lib/limiter/limiter.go +++ b/lib/limiter/limiter.go @@ -180,7 +180,7 @@ func (l *Limiter) StreamServerInterceptor(srv interface{}, serverStream grpc.Ser // WrapListener returns a [Listener] that wraps the provided listener // with one that limits connections -func (l *Limiter) WrapListener(ln net.Listener) *Listener { +func (l *Limiter) WrapListener(ln net.Listener) (*Listener, error) { return NewListener(ln, l.ConnectionsLimiter) } diff --git a/lib/limiter/limiter_test.go b/lib/limiter/limiter_test.go index 1c5208836f98d..cdd48bb9672da 100644 --- a/lib/limiter/limiter_test.go +++ b/lib/limiter/limiter_test.go @@ -407,7 +407,8 @@ func TestListener(t *testing.T) { limiter, err := NewConnectionsLimiter(test.config) require.NoError(t, err) - ln := NewListener(test.listener, limiter) + ln, err := NewListener(test.listener, limiter) + require.NoError(t, err) // open connections without closing to enforce limits conns := make([]net.Conn, 0, connLimit) diff --git a/lib/limiter/listener.go b/lib/limiter/listener.go index 74741ac4d6316..e93325ddc2eee 100644 --- a/lib/limiter/listener.go +++ b/lib/limiter/listener.go @@ -35,11 +35,15 @@ type Listener struct { // NewListener creates a [Listener] that enforces the limits of // the provided [ConnectionsLimiter] on the all connections accepted // by the provided [net.Listener]. -func NewListener(ln net.Listener, limiter *ConnectionsLimiter) *Listener { +func NewListener(ln net.Listener, limiter *ConnectionsLimiter) (*Listener, error) { + if ln == nil { + return nil, trace.BadParameter("listener cannot be nil") + } + return &Listener{ Listener: ln, limiter: limiter, - } + }, nil } // Accept waits for and returns the next connection to the listener diff --git a/lib/service/service.go b/lib/service/service.go index f88754aa54345..5631ac4c55596 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2823,7 +2823,12 @@ func (process *TeleportProcess) initSSH() error { }() defer mux.Close() - go s.Serve(limiter.WrapListener(mux.SSH())) + listener, err = limiter.WrapListener(mux.SSH()) + if err != nil { + return trace.Wrap(err) + } + + go s.Serve(listener) } else { // Start the SSH server. This kicks off updating labels and starting the // heartbeat. @@ -4086,6 +4091,11 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } } + rtListener, err := reverseTunnelLimiter.WrapListener(listeners.reverseTunnel) + if err != nil { + return trace.Wrap(err) + } + tsrv, err = reversetunnel.NewServer( reversetunnel.Config{ Context: process.ExitContext(), @@ -4093,7 +4103,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { ID: process.Config.HostUUID, ClusterName: clusterName, ClientTLS: clientTLSConfig, - Listener: reverseTunnelLimiter.WrapListener(listeners.reverseTunnel), + Listener: rtListener, GetHostSigners: sshutils.StaticHostSigners(conn.ServerIdentity.KeySigner), LocalAuthClient: conn.Client, LocalAccessPoint: accessPoint, @@ -4607,14 +4617,24 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // start ssh server go func() { - if err := sshProxy.Serve(proxyLimiter.WrapListener(listeners.ssh)); err != nil && !utils.IsOKNetworkError(err) { + listener, err := proxyLimiter.WrapListener(listeners.ssh) + if err != nil { + logger.ErrorContext(process.ExitContext(), "Failed to set up SSH proxy server", "error", err) + return + } + if err := sshProxy.Serve(listener); err != nil && !utils.IsOKNetworkError(err) { logger.ErrorContext(process.ExitContext(), "SSH proxy server terminated unexpectedly", "error", err) } }() // start grpc server go func() { - if err := sshGRPCServer.Serve(proxyLimiter.WrapListener(listeners.sshGRPC)); err != nil && !utils.IsOKNetworkError(err) && !errors.Is(err, grpc.ErrServerStopped) { + listener, err := proxyLimiter.WrapListener(listeners.sshGRPC) + if err != nil { + logger.ErrorContext(process.ExitContext(), "Failed to set up SSH proxy server", "error", err) + return + } + if err := sshGRPCServer.Serve(listener); err != nil && !utils.IsOKNetworkError(err) && !errors.Is(err, grpc.ErrServerStopped) { logger.ErrorContext(process.ExitContext(), "SSH gRPC server terminated unexpectedly", "error", err) } }() diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index 7faedf42c0c98..fb1a431105aa6 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -335,7 +335,12 @@ func (s *Server) Start() error { return trace.ConvertSystemError(err) } - if err := s.SetListener(s.limiter.WrapListener(listener)); err != nil { + listener, err = s.limiter.WrapListener(listener) + if err != nil { + return trace.Wrap(err) + } + + if err := s.SetListener(listener); err != nil { return trace.Wrap(err) } }