Skip to content

Commit

Permalink
Reduce flakiness of TestBotDatabaseTunnel (#40179)
Browse files Browse the repository at this point in the history
* Reduce flakiness of `TestBotDatabaseTunnel`

* Provide listener in test
  • Loading branch information
strideynet authored Apr 4, 2024
1 parent 0f5f612 commit 3ba6df5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 24 deletions.
7 changes: 6 additions & 1 deletion lib/tbot/config/service_database_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package config

import (
"net"
"net/url"

"github.com/gravitational/trace"
Expand All @@ -44,6 +45,10 @@ type DatabaseTunnelService struct {
Database string `yaml:"database"`
// Username is the database username to proxy as.
Username string `yaml:"username"`

// Listener overrides "listen" and directly provides an opened listener to
// use.
Listener net.Listener `yaml:"-"`
}

func (s *DatabaseTunnelService) Type() string {
Expand All @@ -66,7 +71,7 @@ func (s *DatabaseTunnelService) UnmarshalYAML(node *yaml.Node) error {

func (s *DatabaseTunnelService) CheckAndSetDefaults() error {
switch {
case s.Listen == "":
case s.Listen == "" && s.Listener == nil:
return trace.BadParameter("listen: should not be empty")
case s.Service == "":
return trace.BadParameter("service: should not be empty")
Expand Down
29 changes: 16 additions & 13 deletions lib/tbot/service_database_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,21 +176,24 @@ func (s *DatabaseTunnelService) Run(ctx context.Context) error {
ctx, span := tracer.Start(ctx, "DatabaseTunnelService/Run")
defer span.End()

listenUrl, err := url.Parse(s.cfg.Listen)
if err != nil {
return trace.Wrap(err, "parsing listen url")
}
l := s.cfg.Listener
if l == nil {
listenUrl, err := url.Parse(s.cfg.Listen)
if err != nil {
return trace.Wrap(err, "parsing listen url")
}

s.log.WithField("address", listenUrl.String()).Debug("Opening listener for database tunnel.")
l, err := net.Listen("tcp", listenUrl.Host)
if err != nil {
return trace.Wrap(err, "opening listener")
}
defer func() {
if err := l.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) {
s.log.WithError(err).Error("Failed to close listener")
s.log.WithField("address", listenUrl.String()).Debug("Opening listener for database tunnel.")
l, err = net.Listen("tcp", listenUrl.Host)
if err != nil {
return trace.Wrap(err, "opening listener")
}
}()
defer func() {
if err := l.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) {
s.log.WithError(err).Error("Failed to close listener")
}
}()
}

lpCfg, err := s.buildLocalProxyConfig(ctx)
if err != nil {
Expand Down
30 changes: 20 additions & 10 deletions lib/tbot/tbot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net"
"os"
"path"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -759,6 +760,12 @@ func TestBotDatabaseTunnel(t *testing.T) {
role, err = rootClient.UpsertRole(ctx, role)
require.NoError(t, err)

botListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
botListener.Close()
})

// Prepare the bot config
onboarding, _ := testhelpers.MakeBot(t, rootClient, "test", role.GetName())
botConfig := testhelpers.DefaultBotConfig(
Expand All @@ -770,8 +777,7 @@ func TestBotDatabaseTunnel(t *testing.T) {
Insecure: true,
ServiceConfigs: []config.ServiceConfig{
&config.DatabaseTunnelService{
// TODO: Perhaps allow FD or listener to be injected
Listen: "tcp://127.0.0.1:39933",
Listener: botListener,
Service: "test-database",
Database: "mydb",
Username: "llama",
Expand All @@ -783,16 +789,20 @@ func TestBotDatabaseTunnel(t *testing.T) {
b := New(botConfig, log)

// Spin up goroutine for bot to run in
botCtx, cancelBot := context.WithCancel(ctx)
botCh := make(chan error, 1)
ctx, cancel := context.WithCancel(ctx)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
botCh <- b.Run(botCtx)
defer wg.Done()
err := b.Run(ctx)
assert.NoError(t, err, "bot should not exit with error")
cancel()
}()

// We can't predict exactly when the tunnel will be ready so we use
// EventuallyWithT to retry.
require.EventuallyWithT(t, func(t *assert.CollectT) {
conn, err := pgconn.Connect(ctx, "postgres://127.0.0.1:39933/mydb?user=llama")
conn, err := pgconn.Connect(ctx, fmt.Sprintf("postgres://%s/mydb?user=llama", botListener.Addr().String()))
if !assert.NoError(t, err) {
return
}
Expand All @@ -801,9 +811,9 @@ func TestBotDatabaseTunnel(t *testing.T) {
}()
_, err = conn.Exec(ctx, "SELECT 1;").ReadAll()
assert.NoError(t, err)
}, 5*time.Second, 100*time.Millisecond)
}, 10*time.Second, 100*time.Millisecond)

// Shut down bot and make sure it exits cleanly.
cancelBot()
require.NoError(t, <-botCh)
// Shut down bot and make sure it exits.
cancel()
wg.Wait()
}

0 comments on commit 3ba6df5

Please sign in to comment.