Skip to content

Commit

Permalink
Check if listener is nil before wrapping. (#43201)
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Jun 21, 2024
1 parent 91df32c commit 6e76f28
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 9 deletions.
2 changes: 1 addition & 1 deletion lib/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (l *Limiter) StreamServerInterceptor(srv interface{}, serverStream grpc.Ser

// WrapListener returns a [Listener] that wraps the provided listener
// with one that limits connections
func (l *Limiter) WrapListener(ln net.Listener) *Listener {
func (l *Limiter) WrapListener(ln net.Listener) (*Listener, error) {
return NewListener(ln, l.ConnectionsLimiter)
}

Expand Down
3 changes: 2 additions & 1 deletion lib/limiter/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ func TestListener(t *testing.T) {
limiter, err := NewConnectionsLimiter(test.config)
require.NoError(t, err)

ln := NewListener(test.listener, limiter)
ln, err := NewListener(test.listener, limiter)
require.NoError(t, err)

// open connections without closing to enforce limits
conns := make([]net.Conn, 0, connLimit)
Expand Down
8 changes: 6 additions & 2 deletions lib/limiter/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,15 @@ type Listener struct {
// NewListener creates a [Listener] that enforces the limits of
// the provided [ConnectionsLimiter] on the all connections accepted
// by the provided [net.Listener].
func NewListener(ln net.Listener, limiter *ConnectionsLimiter) *Listener {
func NewListener(ln net.Listener, limiter *ConnectionsLimiter) (*Listener, error) {
if ln == nil {
return nil, trace.BadParameter("listener cannot be nil")
}

return &Listener{
Listener: ln,
limiter: limiter,
}
}, nil
}

// Accept waits for and returns the next connection to the listener
Expand Down
28 changes: 24 additions & 4 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2928,7 +2928,12 @@ func (process *TeleportProcess) initSSH() error {
}()
defer mux.Close()

go s.Serve(limiter.WrapListener(mux.SSH()))
listener, err = limiter.WrapListener(mux.SSH())
if err != nil {
return trace.Wrap(err)
}

go s.Serve(listener)
} else {
// Start the SSH server. This kicks off updating labels and starting the
// heartbeat.
Expand Down Expand Up @@ -4197,14 +4202,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
}

rtListener, err := reverseTunnelLimiter.WrapListener(listeners.reverseTunnel)
if err != nil {
return trace.Wrap(err)
}

tsrv, err = reversetunnel.NewServer(
reversetunnel.Config{
Context: process.ExitContext(),
Component: teleport.Component(teleport.ComponentProxy, process.id),
ID: process.Config.HostUUID,
ClusterName: clusterName,
ClientTLS: clientTLSConfig,
Listener: reverseTunnelLimiter.WrapListener(listeners.reverseTunnel),
Listener: rtListener,
GetHostSigners: sshutils.StaticHostSigners(conn.ServerIdentity.KeySigner),
LocalAuthClient: conn.Client,
LocalAccessPoint: accessPoint,
Expand Down Expand Up @@ -4721,14 +4731,24 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

// start ssh server
go func() {
if err := sshProxy.Serve(proxyLimiter.WrapListener(listeners.ssh)); err != nil && !utils.IsOKNetworkError(err) {
listener, err := proxyLimiter.WrapListener(listeners.ssh)
if err != nil {
logger.ErrorContext(process.ExitContext(), "Failed to set up SSH proxy server", "error", err)
return
}
if err := sshProxy.Serve(listener); err != nil && !utils.IsOKNetworkError(err) {
logger.ErrorContext(process.ExitContext(), "SSH proxy server terminated unexpectedly", "error", err)
}
}()

// start grpc server
go func() {
if err := sshGRPCServer.Serve(proxyLimiter.WrapListener(listeners.sshGRPC)); err != nil && !utils.IsOKNetworkError(err) && !errors.Is(err, grpc.ErrServerStopped) {
listener, err := proxyLimiter.WrapListener(listeners.sshGRPC)
if err != nil {
logger.ErrorContext(process.ExitContext(), "Failed to set up SSH proxy server", "error", err)
return
}
if err := sshGRPCServer.Serve(listener); err != nil && !utils.IsOKNetworkError(err) && !errors.Is(err, grpc.ErrServerStopped) {
logger.ErrorContext(process.ExitContext(), "SSH gRPC server terminated unexpectedly", "error", err)
}
}()
Expand Down
7 changes: 6 additions & 1 deletion lib/sshutils/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,12 @@ func (s *Server) Start() error {
return trace.ConvertSystemError(err)
}

if err := s.SetListener(s.limiter.WrapListener(listener)); err != nil {
listener, err = s.limiter.WrapListener(listener)
if err != nil {
return trace.Wrap(err)
}

if err := s.SetListener(listener); err != nil {
return trace.Wrap(err)
}
}
Expand Down

0 comments on commit 6e76f28

Please sign in to comment.