From e46a3bd6a7d3044cb16ef0e4cb214cc21d82ab0f Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Tue, 19 Nov 2024 15:28:15 -0300 Subject: [PATCH] refactor: use functions from srvconnect to authenticate/connect to db srv --- integration/db/db_integration_test.go | 3 +- lib/srv/db/access_test.go | 3 +- lib/srv/db/proxyserver.go | 193 ++++++-------------------- 3 files changed, 43 insertions(+), 156 deletions(-) diff --git a/integration/db/db_integration_test.go b/integration/db/db_integration_test.go index ec68c78c0ed96..6aa4202e62c54 100644 --- a/integration/db/db_integration_test.go +++ b/integration/db/db_integration_test.go @@ -44,6 +44,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db" "github.com/gravitational/teleport/lib/srv/db/cassandra" "github.com/gravitational/teleport/lib/srv/db/common" + dbsrvconnect "github.com/gravitational/teleport/lib/srv/db/common/srvconnect" "github.com/gravitational/teleport/lib/srv/db/mongodb" "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/postgres" @@ -607,7 +608,7 @@ func TestDatabaseAccessPostgresSeparateListenerTLSDisabled(t *testing.T) { func init() { // Override database agents shuffle behavior to ensure they're always // tried in the same order during tests. Used for HA tests. - db.SetShuffleFunc(db.ShuffleSort) + db.SetShuffleFunc(dbsrvconnect.ShuffleSort) } // testHARootCluster verifies that proxy falls back to a healthy diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 37848ea294c8e..c52ec8debc702 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -81,6 +81,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/clickhouse" "github.com/gravitational/teleport/lib/srv/db/cloud" "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/common/srvconnect" "github.com/gravitational/teleport/lib/srv/db/dynamodb" "github.com/gravitational/teleport/lib/srv/db/elasticsearch" "github.com/gravitational/teleport/lib/srv/db/mongodb" @@ -2257,7 +2258,7 @@ func (c *testContext) Close() error { func init() { // Override database agents shuffle behavior to ensure they're always // tried in the same order during tests. Used for HA tests. - SetShuffleFunc(ShuffleSort) + SetShuffleFunc(srvconnect.ShuffleSort) } func setupTestContext(ctx context.Context, t testing.TB, withDatabases ...withDatabaseOption) *testContext { diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index fb72e2accfe0d..003b9caf00a67 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -21,16 +21,12 @@ package db import ( "context" "crypto/tls" - "crypto/x509" "errors" - "fmt" "io" "log/slog" "math/rand" "net" - "sort" "strconv" - "strings" "sync" "time" @@ -38,27 +34,23 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/client/proto" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" - "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/enterprise" + "github.com/gravitational/teleport/lib/srv/db/common/srvconnect" "github.com/gravitational/teleport/lib/srv/db/dbutils" "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/postgres" "github.com/gravitational/teleport/lib/srv/db/sqlserver" "github.com/gravitational/teleport/lib/srv/ingress" - "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -105,42 +97,22 @@ type ProxyServerConfig struct { MySQLServerVersion string } -// ShuffleFunc defines a function that shuffles a list of database servers. -type ShuffleFunc func([]types.DatabaseServer) []types.DatabaseServer - -// ShuffleRandom is a ShuffleFunc that randomizes the order of database servers. -// Used to provide load balancing behavior when proxying to multiple agents. -func ShuffleRandom(servers []types.DatabaseServer) []types.DatabaseServer { - rand.New(rand.NewSource(time.Now().UnixNano())).Shuffle( - len(servers), func(i, j int) { - servers[i], servers[j] = servers[j], servers[i] - }) - return servers -} - -// ShuffleSort is a ShuffleFunc that sorts database servers by name and host ID. -// Used to provide predictable behavior in tests. -func ShuffleSort(servers []types.DatabaseServer) []types.DatabaseServer { - sort.Sort(types.DatabaseServers(servers)) - return servers -} - var ( // mu protects the shuffleFunc global access. mu sync.RWMutex // shuffleFunc provides shuffle behavior for multiple database agents. - shuffleFunc ShuffleFunc = ShuffleRandom + shuffleFunc srvconnect.ShuffleFunc = srvconnect.ShuffleRandom ) // SetShuffleFunc sets the shuffle behavior when proxying to multiple agents. -func SetShuffleFunc(fn ShuffleFunc) { +func SetShuffleFunc(fn srvconnect.ShuffleFunc) { mu.Lock() defer mu.Unlock() shuffleFunc = fn } // getShuffleFunc returns the configured function used to shuffle agents. -func getShuffleFunc() ShuffleFunc { +func getShuffleFunc() srvconnect.ShuffleFunc { mu.RLock() defer mu.RUnlock() return shuffleFunc @@ -453,59 +425,38 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext } else { labels = getLabelsFromDB(nil) } - labels["available_db_servers"] = strconv.Itoa(len(proxyCtx.Servers)) - defer observeLatency(connectionSetupTime.With(labels))() - var attemptedServers int + var ( + serviceConn net.Conn + stats srvconnect.ConnectStats + err error + ) + defer func() { - dialAttemptedServers.With(labels).Observe(float64(attemptedServers)) + dialAttemptedServers.With(labels).Observe(float64(stats.AttemptedServers)) + dialAttempts.With(labels).Add(float64(stats.DialAttempts)) + dialFailures.With(labels).Add(float64(stats.DialFailures)) }() - // There may be multiple database servers proxying the same database. If - // we get a connection problem error trying to dial one of them, likely - // the database server is down so try the next one. - for _, server := range getShuffleFunc()(proxyCtx.Servers) { - attemptedServers++ - s.log.DebugContext(ctx, "Dialing to database service.", "server", server) - tlsConfig, err := s.getConfigForServer(ctx, proxyCtx.Identity, server) - if err != nil { - return nil, trace.Wrap(err) - } - - dialAttempts.With(labels).Inc() - serviceConn, err := proxyCtx.Cluster.Dial(reversetunnelclient.DialParams{ - From: clientSrcAddr, - To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode}, - OriginalClientDstAddr: clientDstAddr, - ServerID: fmt.Sprintf("%v.%v", server.GetHostID(), proxyCtx.Cluster.GetName()), - ConnType: types.DatabaseTunnel, - ProxyIDs: server.GetProxyIDs(), - }) - if err != nil { - dialFailures.With(labels).Inc() - // If an agent is down, we'll retry on the next one (if available). - if isReverseTunnelDownError(err) { - s.log.WarnContext(ctx, "Failed to dial database service.", "server", server, "error", err) - continue - } - return nil, trace.Wrap(err) - } - // Upgrade the connection so the client identity can be passed to the - // remote server during TLS handshake. On the remote side, the connection - // received from the reverse tunnel will be handled by tls.Server. - serviceConn = tls.Client(serviceConn, tlsConfig) - return serviceConn, nil + serviceConn, stats, err = srvconnect.Connect(ctx, srvconnect.ConnectParams{ + Logger: s.log, + Identity: proxyCtx.Identity, + Servers: proxyCtx.Servers, + ShuffleFunc: getShuffleFunc(), + Cluster: proxyCtx.Cluster, + Dialer: proxyCtx.Cluster, + CertSigner: s.cfg.AuthClient, + AuthPreference: s.cfg.AuthClient, + ClientSrcAddr: clientSrcAddr, + ClientDstAddr: clientDstAddr, + }) + if err != nil { + return nil, trace.Wrap(err) } - return nil, trace.BadParameter("failed to connect to any of the database servers") -} -// isReverseTunnelDownError returns true if the provided error indicates that -// the reverse tunnel connection is down e.g. because the agent is down. -func isReverseTunnelDownError(err error) bool { - return trace.IsConnectionProblem(err) || - strings.Contains(err.Error(), reversetunnelclient.NoDatabaseTunnel) + return serviceConn, nil } // Proxy starts proxying all traffic received from database client between @@ -569,94 +520,28 @@ func (s *ProxyServer) Authorize(ctx context.Context, tlsConn utils.TLSConn, para if params.ClientIP != "" { identity.LoginIP = params.ClientIP } - cluster, servers, err := s.getDatabaseServers(ctx, identity) - if err != nil { - return nil, trace.Wrap(err) - } - return &common.ProxyContext{ - Identity: identity, - Cluster: cluster, - Servers: servers, - AuthContext: authContext, - }, nil -} - -// getDatabaseServers finds database servers that proxy the database instance -// encoded in the provided identity. -func (s *ProxyServer) getDatabaseServers(ctx context.Context, identity tlsca.Identity) (reversetunnelclient.RemoteSite, []types.DatabaseServer, error) { cluster, err := s.cfg.Tunnel.GetSite(identity.RouteToCluster) - if err != nil { - return nil, nil, trace.Wrap(err) - } - accessPoint, err := cluster.CachingAccessPoint() - if err != nil { - return nil, nil, trace.Wrap(err) - } - servers, err := accessPoint.GetDatabaseServers(ctx, apidefaults.Namespace) - if err != nil { - return nil, nil, trace.Wrap(err) - } - s.log.DebugContext(ctx, "Available database servers.", "cluster", cluster.GetName(), "servers", servers) - // Find out which database servers proxy the database a user is - // connecting to using routing information from identity. - var result []types.DatabaseServer - for _, server := range servers { - if server.GetDatabase().GetName() == identity.RouteToDatabase.ServiceName { - result = append(result, server) - } - } - if len(result) != 0 { - return cluster, result, nil - } - return nil, nil, trace.NotFound("database %q not found among registered databases in cluster %q", - identity.RouteToDatabase.ServiceName, - identity.RouteToCluster) -} - -// getConfigForServer returns TLS config used for establishing connection -// to a remote database server over reverse tunnel. -func (s *ProxyServer) getConfigForServer(ctx context.Context, identity tlsca.Identity, server types.DatabaseServer) (*tls.Config, error) { - defer observeLatency(tlsConfigTime.With(getLabelsFromDB(server.GetDatabase())))() - - privateKey, err := cryptosuites.GenerateKey(ctx, - cryptosuites.GetCurrentSuiteFromAuthPreference(s.cfg.AccessPoint), - cryptosuites.ProxyToDatabaseAgent) - if err != nil { - return nil, trace.Wrap(err) - } - subject, err := identity.Subject() if err != nil { return nil, trace.Wrap(err) } - csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey) + accessPoint, err := cluster.CachingAccessPoint() if err != nil { return nil, trace.Wrap(err) } - - response, err := s.cfg.AuthClient.SignDatabaseCSR(ctx, &proto.DatabaseCSRRequest{ - CSR: csr, - ClusterName: identity.RouteToCluster, + servers, err := srvconnect.GetDatabaseServers(ctx, srvconnect.GetDatabaseServersParams{ + Logger: s.log, + Cluster: cluster, + DatabaseServersGetter: accessPoint, + Identity: identity, }) if err != nil { return nil, trace.Wrap(err) } - - cert, err := keys.TLSCertificateForSigner(privateKey, response.Cert) - if err != nil { - return nil, trace.Wrap(err) - } - pool := x509.NewCertPool() - for _, caCert := range response.CACerts { - ok := pool.AppendCertsFromPEM(caCert) - if !ok { - return nil, trace.BadParameter("failed to append CA certificate") - } - } - - return &tls.Config{ - ServerName: server.GetHostname(), - Certificates: []tls.Certificate{cert}, - RootCAs: pool, + return &common.ProxyContext{ + Identity: identity, + Cluster: cluster, + Servers: servers, + AuthContext: authContext, }, nil }