From 3ba6df5927ec7f4da2b5699c2fc800160d1bed99 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Thu, 4 Apr 2024 12:55:48 +0100 Subject: [PATCH] Reduce flakiness of `TestBotDatabaseTunnel` (#40179) * Reduce flakiness of `TestBotDatabaseTunnel` * Provide listener in test --- lib/tbot/config/service_database_tunnel.go | 7 ++++- lib/tbot/service_database_tunnel.go | 29 +++++++++++---------- lib/tbot/tbot_test.go | 30 ++++++++++++++-------- 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/lib/tbot/config/service_database_tunnel.go b/lib/tbot/config/service_database_tunnel.go index 4b0497f7a5857..be035a5091ff4 100644 --- a/lib/tbot/config/service_database_tunnel.go +++ b/lib/tbot/config/service_database_tunnel.go @@ -19,6 +19,7 @@ package config import ( + "net" "net/url" "github.com/gravitational/trace" @@ -44,6 +45,10 @@ type DatabaseTunnelService struct { Database string `yaml:"database"` // Username is the database username to proxy as. Username string `yaml:"username"` + + // Listener overrides "listen" and directly provides an opened listener to + // use. + Listener net.Listener `yaml:"-"` } func (s *DatabaseTunnelService) Type() string { @@ -66,7 +71,7 @@ func (s *DatabaseTunnelService) UnmarshalYAML(node *yaml.Node) error { func (s *DatabaseTunnelService) CheckAndSetDefaults() error { switch { - case s.Listen == "": + case s.Listen == "" && s.Listener == nil: return trace.BadParameter("listen: should not be empty") case s.Service == "": return trace.BadParameter("service: should not be empty") diff --git a/lib/tbot/service_database_tunnel.go b/lib/tbot/service_database_tunnel.go index 7cd1c7ce67929..89ce0d821cbc1 100644 --- a/lib/tbot/service_database_tunnel.go +++ b/lib/tbot/service_database_tunnel.go @@ -176,21 +176,24 @@ func (s *DatabaseTunnelService) Run(ctx context.Context) error { ctx, span := tracer.Start(ctx, "DatabaseTunnelService/Run") defer span.End() - listenUrl, err := url.Parse(s.cfg.Listen) - if err != nil { - return trace.Wrap(err, "parsing listen url") - } + l := s.cfg.Listener + if l == nil { + listenUrl, err := url.Parse(s.cfg.Listen) + if err != nil { + return trace.Wrap(err, "parsing listen url") + } - s.log.WithField("address", listenUrl.String()).Debug("Opening listener for database tunnel.") - l, err := net.Listen("tcp", listenUrl.Host) - if err != nil { - return trace.Wrap(err, "opening listener") - } - defer func() { - if err := l.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) { - s.log.WithError(err).Error("Failed to close listener") + s.log.WithField("address", listenUrl.String()).Debug("Opening listener for database tunnel.") + l, err = net.Listen("tcp", listenUrl.Host) + if err != nil { + return trace.Wrap(err, "opening listener") } - }() + defer func() { + if err := l.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) { + s.log.WithError(err).Error("Failed to close listener") + } + }() + } lpCfg, err := s.buildLocalProxyConfig(ctx) if err != nil { diff --git a/lib/tbot/tbot_test.go b/lib/tbot/tbot_test.go index e9b82dd855a98..7ac030b394efd 100644 --- a/lib/tbot/tbot_test.go +++ b/lib/tbot/tbot_test.go @@ -26,6 +26,7 @@ import ( "net" "os" "path" + "sync" "testing" "time" @@ -759,6 +760,12 @@ func TestBotDatabaseTunnel(t *testing.T) { role, err = rootClient.UpsertRole(ctx, role) require.NoError(t, err) + botListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + botListener.Close() + }) + // Prepare the bot config onboarding, _ := testhelpers.MakeBot(t, rootClient, "test", role.GetName()) botConfig := testhelpers.DefaultBotConfig( @@ -770,8 +777,7 @@ func TestBotDatabaseTunnel(t *testing.T) { Insecure: true, ServiceConfigs: []config.ServiceConfig{ &config.DatabaseTunnelService{ - // TODO: Perhaps allow FD or listener to be injected - Listen: "tcp://127.0.0.1:39933", + Listener: botListener, Service: "test-database", Database: "mydb", Username: "llama", @@ -783,16 +789,20 @@ func TestBotDatabaseTunnel(t *testing.T) { b := New(botConfig, log) // Spin up goroutine for bot to run in - botCtx, cancelBot := context.WithCancel(ctx) - botCh := make(chan error, 1) + ctx, cancel := context.WithCancel(ctx) + wg := sync.WaitGroup{} + wg.Add(1) go func() { - botCh <- b.Run(botCtx) + defer wg.Done() + err := b.Run(ctx) + assert.NoError(t, err, "bot should not exit with error") + cancel() }() // We can't predict exactly when the tunnel will be ready so we use // EventuallyWithT to retry. require.EventuallyWithT(t, func(t *assert.CollectT) { - conn, err := pgconn.Connect(ctx, "postgres://127.0.0.1:39933/mydb?user=llama") + conn, err := pgconn.Connect(ctx, fmt.Sprintf("postgres://%s/mydb?user=llama", botListener.Addr().String())) if !assert.NoError(t, err) { return } @@ -801,9 +811,9 @@ func TestBotDatabaseTunnel(t *testing.T) { }() _, err = conn.Exec(ctx, "SELECT 1;").ReadAll() assert.NoError(t, err) - }, 5*time.Second, 100*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) - // Shut down bot and make sure it exits cleanly. - cancelBot() - require.NoError(t, <-botCh) + // Shut down bot and make sure it exits. + cancel() + wg.Wait() }