Skip to content

Commit

Permalink
refactor: use functions from srvconnect to authenticate/connect to db…
Browse files Browse the repository at this point in the history
… srv
  • Loading branch information
gabrielcorado committed Nov 19, 2024
1 parent 6abaf2f commit e46a3bd
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 156 deletions.
3 changes: 2 additions & 1 deletion integration/db/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db"
"github.com/gravitational/teleport/lib/srv/db/cassandra"
"github.com/gravitational/teleport/lib/srv/db/common"
dbsrvconnect "github.com/gravitational/teleport/lib/srv/db/common/srvconnect"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
"github.com/gravitational/teleport/lib/srv/db/mysql"
"github.com/gravitational/teleport/lib/srv/db/postgres"
Expand Down Expand Up @@ -607,7 +608,7 @@ func TestDatabaseAccessPostgresSeparateListenerTLSDisabled(t *testing.T) {
func init() {
// Override database agents shuffle behavior to ensure they're always
// tried in the same order during tests. Used for HA tests.
db.SetShuffleFunc(db.ShuffleSort)
db.SetShuffleFunc(dbsrvconnect.ShuffleSort)
}

// testHARootCluster verifies that proxy falls back to a healthy
Expand Down
3 changes: 2 additions & 1 deletion lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db/clickhouse"
"github.com/gravitational/teleport/lib/srv/db/cloud"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/common/srvconnect"
"github.com/gravitational/teleport/lib/srv/db/dynamodb"
"github.com/gravitational/teleport/lib/srv/db/elasticsearch"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
Expand Down Expand Up @@ -2257,7 +2258,7 @@ func (c *testContext) Close() error {
func init() {
// Override database agents shuffle behavior to ensure they're always
// tried in the same order during tests. Used for HA tests.
SetShuffleFunc(ShuffleSort)
SetShuffleFunc(srvconnect.ShuffleSort)
}

func setupTestContext(ctx context.Context, t testing.TB, withDatabases ...withDatabaseOption) *testContext {
Expand Down
193 changes: 39 additions & 154 deletions lib/srv/db/proxyserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,36 @@ package db
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/gravitational/trace"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/common/enterprise"
"github.com/gravitational/teleport/lib/srv/db/common/srvconnect"
"github.com/gravitational/teleport/lib/srv/db/dbutils"
"github.com/gravitational/teleport/lib/srv/db/mysql"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/srv/db/sqlserver"
"github.com/gravitational/teleport/lib/srv/ingress"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -105,42 +97,22 @@ type ProxyServerConfig struct {
MySQLServerVersion string
}

// ShuffleFunc defines a function that shuffles a list of database servers.
type ShuffleFunc func([]types.DatabaseServer) []types.DatabaseServer

// ShuffleRandom is a ShuffleFunc that randomizes the order of database servers.
// Used to provide load balancing behavior when proxying to multiple agents.
func ShuffleRandom(servers []types.DatabaseServer) []types.DatabaseServer {
rand.New(rand.NewSource(time.Now().UnixNano())).Shuffle(
len(servers), func(i, j int) {
servers[i], servers[j] = servers[j], servers[i]
})
return servers
}

// ShuffleSort is a ShuffleFunc that sorts database servers by name and host ID.
// Used to provide predictable behavior in tests.
func ShuffleSort(servers []types.DatabaseServer) []types.DatabaseServer {
sort.Sort(types.DatabaseServers(servers))
return servers
}

var (
// mu protects the shuffleFunc global access.
mu sync.RWMutex
// shuffleFunc provides shuffle behavior for multiple database agents.
shuffleFunc ShuffleFunc = ShuffleRandom
shuffleFunc srvconnect.ShuffleFunc = srvconnect.ShuffleRandom
)

// SetShuffleFunc sets the shuffle behavior when proxying to multiple agents.
func SetShuffleFunc(fn ShuffleFunc) {
func SetShuffleFunc(fn srvconnect.ShuffleFunc) {
mu.Lock()
defer mu.Unlock()
shuffleFunc = fn
}

// getShuffleFunc returns the configured function used to shuffle agents.
func getShuffleFunc() ShuffleFunc {
func getShuffleFunc() srvconnect.ShuffleFunc {
mu.RLock()
defer mu.RUnlock()
return shuffleFunc
Expand Down Expand Up @@ -453,59 +425,38 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext
} else {
labels = getLabelsFromDB(nil)
}

labels["available_db_servers"] = strconv.Itoa(len(proxyCtx.Servers))

defer observeLatency(connectionSetupTime.With(labels))()

var attemptedServers int
var (
serviceConn net.Conn
stats srvconnect.ConnectStats
err error
)

defer func() {
dialAttemptedServers.With(labels).Observe(float64(attemptedServers))
dialAttemptedServers.With(labels).Observe(float64(stats.AttemptedServers))
dialAttempts.With(labels).Add(float64(stats.DialAttempts))
dialFailures.With(labels).Add(float64(stats.DialFailures))
}()

// There may be multiple database servers proxying the same database. If
// we get a connection problem error trying to dial one of them, likely
// the database server is down so try the next one.
for _, server := range getShuffleFunc()(proxyCtx.Servers) {
attemptedServers++
s.log.DebugContext(ctx, "Dialing to database service.", "server", server)
tlsConfig, err := s.getConfigForServer(ctx, proxyCtx.Identity, server)
if err != nil {
return nil, trace.Wrap(err)
}

dialAttempts.With(labels).Inc()
serviceConn, err := proxyCtx.Cluster.Dial(reversetunnelclient.DialParams{
From: clientSrcAddr,
To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode},
OriginalClientDstAddr: clientDstAddr,
ServerID: fmt.Sprintf("%v.%v", server.GetHostID(), proxyCtx.Cluster.GetName()),
ConnType: types.DatabaseTunnel,
ProxyIDs: server.GetProxyIDs(),
})
if err != nil {
dialFailures.With(labels).Inc()
// If an agent is down, we'll retry on the next one (if available).
if isReverseTunnelDownError(err) {
s.log.WarnContext(ctx, "Failed to dial database service.", "server", server, "error", err)
continue
}
return nil, trace.Wrap(err)
}
// Upgrade the connection so the client identity can be passed to the
// remote server during TLS handshake. On the remote side, the connection
// received from the reverse tunnel will be handled by tls.Server.
serviceConn = tls.Client(serviceConn, tlsConfig)
return serviceConn, nil
serviceConn, stats, err = srvconnect.Connect(ctx, srvconnect.ConnectParams{
Logger: s.log,
Identity: proxyCtx.Identity,
Servers: proxyCtx.Servers,
ShuffleFunc: getShuffleFunc(),
Cluster: proxyCtx.Cluster,
Dialer: proxyCtx.Cluster,
CertSigner: s.cfg.AuthClient,
AuthPreference: s.cfg.AuthClient,
ClientSrcAddr: clientSrcAddr,
ClientDstAddr: clientDstAddr,
})
if err != nil {
return nil, trace.Wrap(err)
}
return nil, trace.BadParameter("failed to connect to any of the database servers")
}

// isReverseTunnelDownError returns true if the provided error indicates that
// the reverse tunnel connection is down e.g. because the agent is down.
func isReverseTunnelDownError(err error) bool {
return trace.IsConnectionProblem(err) ||
strings.Contains(err.Error(), reversetunnelclient.NoDatabaseTunnel)
return serviceConn, nil
}

// Proxy starts proxying all traffic received from database client between
Expand Down Expand Up @@ -569,94 +520,28 @@ func (s *ProxyServer) Authorize(ctx context.Context, tlsConn utils.TLSConn, para
if params.ClientIP != "" {
identity.LoginIP = params.ClientIP
}
cluster, servers, err := s.getDatabaseServers(ctx, identity)
if err != nil {
return nil, trace.Wrap(err)
}
return &common.ProxyContext{
Identity: identity,
Cluster: cluster,
Servers: servers,
AuthContext: authContext,
}, nil
}

// getDatabaseServers finds database servers that proxy the database instance
// encoded in the provided identity.
func (s *ProxyServer) getDatabaseServers(ctx context.Context, identity tlsca.Identity) (reversetunnelclient.RemoteSite, []types.DatabaseServer, error) {
cluster, err := s.cfg.Tunnel.GetSite(identity.RouteToCluster)
if err != nil {
return nil, nil, trace.Wrap(err)
}
accessPoint, err := cluster.CachingAccessPoint()
if err != nil {
return nil, nil, trace.Wrap(err)
}
servers, err := accessPoint.GetDatabaseServers(ctx, apidefaults.Namespace)
if err != nil {
return nil, nil, trace.Wrap(err)
}
s.log.DebugContext(ctx, "Available database servers.", "cluster", cluster.GetName(), "servers", servers)
// Find out which database servers proxy the database a user is
// connecting to using routing information from identity.
var result []types.DatabaseServer
for _, server := range servers {
if server.GetDatabase().GetName() == identity.RouteToDatabase.ServiceName {
result = append(result, server)
}
}
if len(result) != 0 {
return cluster, result, nil
}
return nil, nil, trace.NotFound("database %q not found among registered databases in cluster %q",
identity.RouteToDatabase.ServiceName,
identity.RouteToCluster)
}

// getConfigForServer returns TLS config used for establishing connection
// to a remote database server over reverse tunnel.
func (s *ProxyServer) getConfigForServer(ctx context.Context, identity tlsca.Identity, server types.DatabaseServer) (*tls.Config, error) {
defer observeLatency(tlsConfigTime.With(getLabelsFromDB(server.GetDatabase())))()

privateKey, err := cryptosuites.GenerateKey(ctx,
cryptosuites.GetCurrentSuiteFromAuthPreference(s.cfg.AccessPoint),
cryptosuites.ProxyToDatabaseAgent)
if err != nil {
return nil, trace.Wrap(err)
}
subject, err := identity.Subject()
if err != nil {
return nil, trace.Wrap(err)
}
csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey)
accessPoint, err := cluster.CachingAccessPoint()
if err != nil {
return nil, trace.Wrap(err)
}

response, err := s.cfg.AuthClient.SignDatabaseCSR(ctx, &proto.DatabaseCSRRequest{
CSR: csr,
ClusterName: identity.RouteToCluster,
servers, err := srvconnect.GetDatabaseServers(ctx, srvconnect.GetDatabaseServersParams{
Logger: s.log,
Cluster: cluster,
DatabaseServersGetter: accessPoint,
Identity: identity,
})
if err != nil {
return nil, trace.Wrap(err)
}

cert, err := keys.TLSCertificateForSigner(privateKey, response.Cert)
if err != nil {
return nil, trace.Wrap(err)
}
pool := x509.NewCertPool()
for _, caCert := range response.CACerts {
ok := pool.AppendCertsFromPEM(caCert)
if !ok {
return nil, trace.BadParameter("failed to append CA certificate")
}
}

return &tls.Config{
ServerName: server.GetHostname(),
Certificates: []tls.Certificate{cert},
RootCAs: pool,
return &common.ProxyContext{
Identity: identity,
Cluster: cluster,
Servers: servers,
AuthContext: authContext,
}, nil
}

Expand Down

0 comments on commit e46a3bd

Please sign in to comment.