From 903b5f5cbe63aeb4bcfb5f0c4ea4f555b598d221 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Fri, 13 Dec 2024 23:16:29 -0300 Subject: [PATCH] Add WebSocket handler for WebUI database sessions (#49749) * feat(web): add websocket handler for database webui sessions * refactor: move common structs into a separate package * refactor(web): use ALPN local proxy to dial databases * feat(repl): add default registry * refactor(web): code review suggestions * refactor: update repl config parameters * refactor: move default getter implementation * feat(web): add supports_interactive field on dbs * refactor: code review suggestions * refactor: update database REPL interfaces * chore(web): remove debug print * feat: register postgres repl * refactor(web): update MakeDatabase to receive access checker and interactive * chore(web): remove unused function --- lib/client/alpn.go | 27 +- lib/client/db/postgres/repl/repl.go | 17 +- lib/client/db/postgres/repl/repl_test.go | 6 +- lib/client/db/repl/repl.go | 80 ++++++ lib/defaults/defaults.go | 4 + lib/service/service.go | 9 + lib/service/servicecfg/config.go | 5 + lib/web/apiserver.go | 19 +- lib/web/apiserver_test.go | 10 +- lib/web/databases.go | 335 ++++++++++++++++++++++- lib/web/databases_test.go | 220 +++++++++++++++ lib/web/integrations_awsoidc.go | 7 +- lib/web/kube.go | 4 +- lib/web/servers.go | 34 +-- lib/web/terminal.go | 15 +- lib/web/ui/server.go | 44 ++- lib/web/ui/server_test.go | 33 ++- lib/web/ws_io.go | 12 +- 18 files changed, 793 insertions(+), 88 deletions(-) create mode 100644 lib/client/db/repl/repl.go diff --git a/lib/client/alpn.go b/lib/client/alpn.go index ec4e920923da7..db58eb374701f 100644 --- a/lib/client/alpn.go +++ b/lib/client/alpn.go @@ -85,14 +85,33 @@ type ALPNAuthTunnelConfig struct { // RouteToDatabase contains the destination server that must receive the connection. // Specific for database proxying. RouteToDatabase proto.RouteToDatabase + + // TLSCert specifies the TLS certificate used on the proxy connection. + TLSCert *tls.Certificate +} + +func (c *ALPNAuthTunnelConfig) CheckAndSetDefaults(ctx context.Context) error { + if c.AuthClient == nil { + return trace.BadParameter("missing auth client") + } + + if c.TLSCert == nil { + tlsCert, err := getUserCerts(ctx, c.AuthClient, c.MFAResponse, c.Expires, c.RouteToDatabase, c.ConnectionDiagnosticID) + if err != nil { + return trace.BadParameter("failed to parse private key: %v", err) + } + + c.TLSCert = &tlsCert + } + + return nil } // RunALPNAuthTunnel runs a local authenticated ALPN proxy to another service. // At least one Route (which defines the service) must be defined func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error { - tlsCert, err := getUserCerts(ctx, cfg.AuthClient, cfg.MFAResponse, cfg.Expires, cfg.RouteToDatabase, cfg.ConnectionDiagnosticID) - if err != nil { - return trace.BadParameter("failed to parse private key: %v", err) + if err := cfg.CheckAndSetDefaults(ctx); err != nil { + return trace.Wrap(err) } lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{ @@ -101,7 +120,7 @@ func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error { Protocols: []alpn.Protocol{cfg.Protocol}, Listener: cfg.Listener, ParentContext: ctx, - Cert: tlsCert, + Cert: *cfg.TLSCert, }, alpnproxy.WithALPNConnUpgradeTest(ctx, getClusterCACertPool(cfg.AuthClient))) if err != nil { return trace.Wrap(err) diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index 1f3c6b2cbefc0..514d9160e3efb 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport" clientproto "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/lib/asciitable" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/defaults" ) @@ -44,13 +45,13 @@ type REPL struct { commands map[string]*command } -func New(client io.ReadWriteCloser, serverConn net.Conn, route clientproto.RouteToDatabase) (*REPL, error) { +func New(_ context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, error) { config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s", hostnamePlaceholder)) if err != nil { return nil, trace.Wrap(err) } - config.User = route.Username - config.Database = route.Database + config.User = cfg.Route.Username + config.Database = cfg.Route.Database config.ConnectTimeout = defaults.DatabaseConnectTimeout config.RuntimeParams = map[string]string{ applicationNameParamName: applicationNameParamValue, @@ -63,15 +64,15 @@ func New(client io.ReadWriteCloser, serverConn net.Conn, route clientproto.Route return []string{hostnamePlaceholder}, nil } config.DialFunc = func(_ context.Context, _, _ string) (net.Conn, error) { - return serverConn, nil + return cfg.ServerConn, nil } return &REPL{ connConfig: config, - client: client, - serverConn: serverConn, - route: route, - term: term.NewTerminal(client, ""), + client: cfg.Client, + serverConn: cfg.ServerConn, + route: cfg.Route, + term: term.NewTerminal(cfg.Client, ""), commands: initCommands(), }, nil } diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index 1d571f2bdfcc9..0aa03b84c8023 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -33,6 +33,7 @@ import ( clientproto "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/lib/client/db/postgres/repl/testdata" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/utils/golden" ) @@ -338,7 +339,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( } }(tc) - r, err := New(tc.clientConn, tc.serverConn, tc.route) + instance, err := New(ctx, &dbrepl.NewREPLConfig{Client: tc.clientConn, ServerConn: tc.serverConn, Route: tc.route}) require.NoError(t, err) if !cfg.skipREPLRun { @@ -347,7 +348,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( runCtx, cancelRun := context.WithCancel(ctx) runErrChan := make(chan error, 1) go func() { - runErrChan <- r.Run(runCtx) + runErrChan <- instance.Run(runCtx) }() t.Cleanup(func() { cancelRun() @@ -363,6 +364,7 @@ func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) ( }) } + r, _ := instance.(*REPL) return r, tc } diff --git a/lib/client/db/repl/repl.go b/lib/client/db/repl/repl.go new file mode 100644 index 0000000000000..abfed3dd5b8b6 --- /dev/null +++ b/lib/client/db/repl/repl.go @@ -0,0 +1,80 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "io" + "net" + + "github.com/gravitational/trace" + + clientproto "github.com/gravitational/teleport/api/client/proto" +) + +// NewREPLConfig represents the database REPL constructor config. +type NewREPLConfig struct { + // Client is the user terminal client. + Client io.ReadWriteCloser + // ServerConn is the database server connection. + ServerConn net.Conn + // Route is the session routing information. + Route clientproto.RouteToDatabase +} + +// REPLNewFunc defines the constructor function for database REPL +// sessions. +type REPLNewFunc func(context.Context, *NewREPLConfig) (REPLInstance, error) + +// REPLInstance represents a REPL instance. +type REPLInstance interface { + // Run executes the REPL. This is a blocking operation. + Run(context.Context) error +} + +// REPLRegistry is an interface for initializing REPL instances and checking +// if the database protocol is supported. +type REPLRegistry interface { + // IsSupported returns if a database protocol is supported by any REPL. + IsSupported(protocol string) bool + // NewInstance initializes a new REPL instance given the configuration. + NewInstance(context.Context, *NewREPLConfig) (REPLInstance, error) +} + +// NewREPLGetter creates a new REPL getter given the list of supported REPLs. +func NewREPLGetter(replNewFuncs map[string]REPLNewFunc) REPLRegistry { + return &replRegistry{m: replNewFuncs} +} + +type replRegistry struct { + m map[string]REPLNewFunc +} + +// IsSupported implements REPLGetter. +func (r *replRegistry) IsSupported(protocol string) bool { + _, supported := r.m[protocol] + return supported +} + +// NewInstance implements REPLGetter. +func (r *replRegistry) NewInstance(ctx context.Context, cfg *NewREPLConfig) (REPLInstance, error) { + if newFunc, ok := r.m[cfg.Route.Protocol]; ok { + return newFunc(ctx, cfg) + } + + return nil, trace.NotImplemented("REPL not supported for protocol %q", cfg.Route.Protocol) +} diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 7a7290631cb07..2bfd34c557d69 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -721,6 +721,10 @@ const ( // WebsocketKubeExec provides latency information for a session. WebsocketKubeExec = "k" + + // WebsocketDatabaseSessionRequest is received when a new database session + // is requested. + WebsocketDatabaseSessionRequest = "d" ) // The following are cryptographic primitives Teleport does not support in diff --git a/lib/service/service.go b/lib/service/service.go index e8c344ddd9b50..716ac39c8a97b 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -108,6 +108,8 @@ import ( _ "github.com/gravitational/teleport/lib/backend/pgbk" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/cache" + pgrepl "github.com/gravitational/teleport/lib/client/db/postgres/repl" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/cloud/imds" @@ -1085,6 +1087,12 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { cfg.PluginRegistry = plugin.NewRegistry() } + if cfg.DatabaseREPLRegistry == nil { + cfg.DatabaseREPLRegistry = dbrepl.NewREPLGetter(map[string]dbrepl.REPLNewFunc{ + defaults.ProtocolPostgres: pgrepl.New, + }) + } + var cloudLabels labels.Importer // Check if we're on a cloud instance, and if we should override the node's hostname. @@ -4672,6 +4680,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels, IntegrationAppHandler: connectionsHandler, FeatureWatchInterval: retryutils.HalfJitter(web.DefaultFeatureWatchInterval * 2), + DatabaseREPLRegistry: cfg.DatabaseREPLRegistry, } webHandler, err := web.NewHandler(webConfig) if err != nil { diff --git a/lib/service/servicecfg/config.go b/lib/service/servicecfg/config.go index d3dd237da74cd..6a14f1ceba5d0 100644 --- a/lib/service/servicecfg/config.go +++ b/lib/service/servicecfg/config.go @@ -43,6 +43,7 @@ import ( "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/lite" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/cloud/imds" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -265,6 +266,10 @@ type Config struct { // AccessGraph represents AccessGraph server config AccessGraph AccessGraphConfig + // DatabaseREPLRegistry is used to retrieve datatabase REPL given the + // protocol. + DatabaseREPLRegistry dbrepl.REPLRegistry + // token is either the token needed to join the auth server, or a path pointing to a file // that contains the token // diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 796e8351ca6f7..6c5f952290536 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -80,6 +80,7 @@ import ( "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/client" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/client/sso" "github.com/gravitational/teleport/lib/defaults" dtconfig "github.com/gravitational/teleport/lib/devicetrust/config" @@ -315,6 +316,9 @@ type Config struct { // FeatureWatchInterval is the interval between pings to the auth server // to fetch new cluster features FeatureWatchInterval time.Duration + + // DatabaseREPLRegistry is used for retrieving database REPL. + DatabaseREPLRegistry dbrepl.REPLRegistry } // SetDefaults ensures proper default values are set if @@ -834,6 +838,7 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions h.GET("/webapi/sites/:site/kube/exec/ws", h.WithClusterAuthWebSocket(h.podConnect)) // connect to a pod with exec (via websocket, with auth over websocket) + h.GET("/webapi/sites/:site/db/exec/ws", h.WithClusterAuthWebSocket(h.dbConnect)) // Audit events handlers. h.GET("/webapi/sites/:site/events/search", h.WithClusterAuth(h.clusterSearchEvents)) // search site events @@ -3073,9 +3078,6 @@ func (h *Handler) clusterUnifiedResourcesGet(w http.ResponseWriter, request *htt getUserGroupLookup := h.getUserGroupLookup(request.Context(), clt) - var dbNames, dbUsers []string - hasFetchedDBUsersAndNames := false - unifiedResources := make([]any, 0, len(page)) for _, enriched := range page { switch r := enriched.ResourceWithLabels.(type) { @@ -3087,14 +3089,7 @@ func (h *Handler) clusterUnifiedResourcesGet(w http.ResponseWriter, request *htt unifiedResources = append(unifiedResources, ui.MakeServer(site.GetName(), r, logins, enriched.RequiresRequest)) case types.DatabaseServer: - if !hasFetchedDBUsersAndNames { - dbNames, dbUsers, err = getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - hasFetchedDBUsersAndNames = true - } - db := ui.MakeDatabase(r.GetDatabase(), dbUsers, dbNames, enriched.RequiresRequest) + db := ui.MakeDatabase(r.GetDatabase(), accessChecker, h.cfg.DatabaseREPLRegistry, enriched.RequiresRequest) unifiedResources = append(unifiedResources, db) case types.AppServer: allowedAWSRoles, err := calculateAppLogins(accessChecker, r, enriched.Logins) @@ -3588,6 +3583,7 @@ func (h *Handler) siteNodeConnect( } term, err := NewTerminal(ctx, TerminalHandlerConfig{ + Logger: h.logger, Term: req.Term, SessionCtx: sessionCtx, UserAuthClient: clt, @@ -3740,6 +3736,7 @@ func (h *Handler) podConnect( ws: ws, keepAliveInterval: keepAliveInterval, log: h.log.WithField(teleport.ComponentKey, "pod"), + logger: h.logger.With(teleport.ComponentKey, "pod"), userClient: clt, localCA: hostCA, configServerAddr: serverAddr, diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 08611d1c2390a..72f5b4171f806 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -111,6 +111,7 @@ import ( "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/conntest" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -210,6 +211,9 @@ type webSuiteConfig struct { // clock to use for all server components clock clockwork.FakeClock + + // databaseREPLGetter allows setting custom database REPLs. + databaseREPLGetter dbrepl.REPLRegistry } func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { @@ -508,6 +512,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { return &proxyClientCert, nil }, IntegrationAppHandler: &mockIntegrationAppHandler{}, + DatabaseREPLRegistry: cfg.databaseREPLGetter, } if handlerConfig.HealthCheckAppServer == nil { @@ -7436,6 +7441,7 @@ func TestOverwriteDatabase(t *testing.T) { env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "user", nil /* roles */) + accessChecker := services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, env.server.ClusterName(), nil) initDb, err := types.NewDatabaseV3(types.Metadata{ Name: "postgres", @@ -7476,7 +7482,8 @@ func TestOverwriteDatabase(t *testing.T) { backendDb, err := env.server.Auth().GetDatabase(context.Background(), req.Name) require.NoError(t, err) - require.Equal(t, webui.MakeDatabase(backendDb, nil, nil, false), gotDb) + + require.Equal(t, webui.MakeDatabase(backendDb, accessChecker, proxy.handler.handler.cfg.DatabaseREPLRegistry, false), gotDb) }, }, { @@ -8388,6 +8395,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula return &proxyClientCert, nil }, IntegrationAppHandler: &mockIntegrationAppHandler{}, + DatabaseREPLRegistry: &mockDatabaseREPLRegistry{repl: map[string]dbrepl.REPLNewFunc{}}, }, SetClock(clock)) require.NoError(t, err) diff --git a/lib/web/databases.go b/lib/web/databases.go index 557385d90f84c..f84beeb4ef7a8 100644 --- a/lib/web/databases.go +++ b/lib/web/databases.go @@ -21,28 +21,46 @@ package web import ( "context" "crypto/sha1" + "crypto/tls" "encoding/base64" "encoding/json" "encoding/pem" + "errors" "fmt" + "io" + "log/slog" "net" "net/http" "net/url" + "time" + gogoproto "github.com/gogo/protobuf/proto" + "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" + oteltrace "go.opentelemetry.io/otel/trace" "github.com/gravitational/teleport/api/client/proto" + clientproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/tlsutils" + "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/client" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" + alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" "github.com/gravitational/teleport/lib/ui" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/listener" "github.com/gravitational/teleport/lib/web/scripts" + "github.com/gravitational/teleport/lib/web/terminal" webui "github.com/gravitational/teleport/lib/web/ui" ) @@ -141,12 +159,8 @@ func (h *Handler) handleDatabaseCreateOrOverwrite(w http.ResponseWriter, r *http if err != nil { return nil, trace.Wrap(err) } - dbNames, dbUsers, err := getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - return webui.MakeDatabase(database, dbUsers, dbNames, false /* requiresRequest */), nil + return webui.MakeDatabase(database, accessChecker, h.cfg.DatabaseREPLRegistry, false /* requiresRequest */), nil } // updateDatabaseRequest contains some updatable fields of a database resource. @@ -254,7 +268,12 @@ func (h *Handler) handleDatabasePartialUpdate(w http.ResponseWriter, r *http.Req return nil, trace.Wrap(err) } - return webui.MakeDatabase(database, nil /* dbUsers */, nil /* dbNames */, false /* requiresRequest */), nil + accessChecker, err := sctx.GetUserAccessChecker() + if err != nil { + return nil, trace.Wrap(err) + } + + return webui.MakeDatabase(database, accessChecker, h.cfg.DatabaseREPLRegistry, false /* requiresRequest */), nil } // databaseIAMPolicyResponse is the response type for handleDatabaseGetIAMPolicy. @@ -387,6 +406,310 @@ func (h *Handler) sqlServerConfigureADScriptHandle(w http.ResponseWriter, r *htt return nil, trace.Wrap(err) } +func (h *Handler) dbConnect( + w http.ResponseWriter, + r *http.Request, + p httprouter.Params, + sctx *SessionContext, + site reversetunnelclient.RemoteSite, + ws *websocket.Conn, +) (interface{}, error) { + // Create a context for signaling when the terminal session is over and + // link it first with the trace context from the request context + tctx := oteltrace.ContextWithRemoteSpanContext(context.Background(), oteltrace.SpanContextFromContext(r.Context())) + ctx, cancel := context.WithCancel(tctx) + defer cancel() + h.logger.DebugContext(ctx, "Received database interactive connection") + + req, err := readDatabaseSessionRequest(ws) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || terminal.IsOKWebsocketCloseError(trace.Unwrap(err)) { + h.logger.DebugContext(ctx, "Database interactive session closed before receiving request") + return nil, nil + } + + var netError net.Error + if errors.As(trace.Unwrap(err), &netError) && netError.Timeout() { + return nil, trace.BadParameter("timed out waiting for database connect request data on websocket connection") + } + + return nil, trace.Wrap(err) + } + + log := h.logger.With( + "protocol", req.Protocol, + "service_name", req.ServiceName, + "database_name", req.DatabaseName, + "database_user", req.DatabaseUser, + "database_roles", req.DatabaseRoles, + ) + log.DebugContext(ctx, "Received database interactive session request") + + if !h.cfg.DatabaseREPLRegistry.IsSupported(req.Protocol) { + log.ErrorContext(ctx, "Unsupported database protocol") + return nil, trace.NotImplemented("%q database protocol not supported for REPL sessions", req.Protocol) + } + + accessPoint, err := site.CachingAccessPoint() + if err != nil { + return nil, trace.Wrap(err) + } + + netConfig, err := accessPoint.GetClusterNetworkingConfig(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + clt, err := sctx.GetUserClient(ctx, site) + if err != nil { + return nil, trace.Wrap(err) + } + + stream := terminal.NewStream(ctx, terminal.StreamConfig{WS: ws}) + defer stream.Close() + + replConn, alpnConn := net.Pipe() + sess := &databaseInteractiveSession{ + ctx: ctx, + log: log, + req: req, + stream: stream, + ws: ws, + sctx: sctx, + site: site, + clt: clt, + replConn: replConn, + alpnConn: alpnConn, + keepAliveInterval: netConfig.GetKeepAliveInterval(), + registry: h.cfg.DatabaseREPLRegistry, + proxyAddr: h.PublicProxyAddr(), + } + defer sess.Close() + + if err := sess.Run(); err != nil { + log.ErrorContext(ctx, "Database interactive session exited with error", "error", err) + return nil, trace.Wrap(err) + } + + return nil, nil +} + +// DatabaseSessionRequest describes a request to create a web-based terminal +// database session. +type DatabaseSessionRequest struct { + // ServiceName is the database resource ID the user will be connected. + ServiceName string `json:"serviceName"` + // Protocol is the database protocol. + Protocol string `json:"protocol"` + // DatabaseName is the database name the session will use. + DatabaseName string `json:"dbName"` + // DatabaseUser is the database user used on the session. + DatabaseUser string `json:"dbUser"` + // DatabaseRoles are ratabase roles that will be attached to the user when + // connecting to the database. + DatabaseRoles []string `json:"dbRoles"` +} + +// databaseConnectionRequestWaitTimeout defines how long the server will wait +// for the user to send the connection request. +const databaseConnectionRequestWaitTimeout = defaults.HeadlessLoginTimeout + +// readDatabaseSessionRequest reads the database session requestion message from +// websocket connection. +func readDatabaseSessionRequest(ws *websocket.Conn) (*DatabaseSessionRequest, error) { + err := ws.SetReadDeadline(time.Now().Add(databaseConnectionRequestWaitTimeout)) + if err != nil { + return nil, trace.Wrap(err, "failed to set read deadline for websocket connection") + } + + messageType, bytes, err := ws.ReadMessage() + if err != nil { + return nil, trace.Wrap(err) + } + + if err := ws.SetReadDeadline(time.Time{}); err != nil { + return nil, trace.Wrap(err, "failed to set read deadline for websocket connection") + } + + if messageType != websocket.BinaryMessage { + return nil, trace.BadParameter("expected binary message of type websocket.BinaryMessage, got %v", messageType) + } + + var envelope terminal.Envelope + if err := gogoproto.Unmarshal(bytes, &envelope); err != nil { + return nil, trace.BadParameter("failed to parse envelope: %v", err) + } + + if envelope.Type != defaults.WebsocketDatabaseSessionRequest { + return nil, trace.BadParameter("expected database session request but got %q", envelope.Type) + } + + var req DatabaseSessionRequest + if err := json.Unmarshal([]byte(envelope.Payload), &req); err != nil { + return nil, trace.Wrap(err) + } + + return &req, nil +} + +type databaseInteractiveSession struct { + ctx context.Context + ws *websocket.Conn + stream *terminal.Stream + log *slog.Logger + req *DatabaseSessionRequest + sctx *SessionContext + site reversetunnelclient.RemoteSite + clt authclient.ClientI + replConn net.Conn + alpnConn net.Conn + keepAliveInterval time.Duration + registry dbrepl.REPLRegistry + proxyAddr string +} + +func (s *databaseInteractiveSession) Run() error { + tlsCert, route, err := s.issueCerts() + if err != nil { + return trace.Wrap(err) + } + + if err := s.sendSessionMetadata(); err != nil { + return trace.Wrap(err) + } + + alpnProtocol, err := alpncommon.ToALPNProtocol(route.Protocol) + if err != nil { + return trace.Wrap(err) + } + + go startWSPingLoop(s.ctx, s.ws, s.keepAliveInterval, s.log, s.Close) + + err = client.RunALPNAuthTunnel(s.ctx, client.ALPNAuthTunnelConfig{ + AuthClient: s.clt, + Listener: listener.NewSingleUseListener(s.alpnConn), + Protocol: alpnProtocol, + PublicProxyAddr: s.proxyAddr, + RouteToDatabase: *route, + TLSCert: tlsCert, + }) + if err != nil { + return trace.Wrap(err) + } + + repl, err := s.registry.NewInstance(s.ctx, &dbrepl.NewREPLConfig{ + Client: s.stream, + ServerConn: s.replConn, + Route: *route, + }) + if err != nil { + return trace.Wrap(err) + } + + s.log.DebugContext(s.ctx, "Starting database interactive session") + if err := repl.Run(s.ctx); err != nil { + return trace.Wrap(err) + } + + s.log.DebugContext(s.ctx, "Database interactive session exited with success") + return nil +} + +func (s *databaseInteractiveSession) Close() error { + s.replConn.Close() + return s.ws.Close() +} + +// issueCerts performs the MFA (if required) and generate the user session +// certificates. +func (s *databaseInteractiveSession) issueCerts() (*tls.Certificate, *clientproto.RouteToDatabase, error) { + pk, err := keys.ParsePrivateKey(s.sctx.cfg.Session.GetTLSPriv()) + if err != nil { + return nil, nil, trace.Wrap(err, "failed getting user private key from the session") + } + + publicKeyPEM, err := keys.MarshalPublicKey(pk.Public()) + if err != nil { + return nil, nil, trace.Wrap(err, "failed to marshal public key") + } + + routeToDatabase := clientproto.RouteToDatabase{ + Protocol: s.req.Protocol, + ServiceName: s.req.ServiceName, + Username: s.req.DatabaseUser, + Database: s.req.DatabaseName, + Roles: s.req.DatabaseRoles, + } + + certsReq := clientproto.UserCertsRequest{ + TLSPublicKey: publicKeyPEM, + Username: s.sctx.GetUser(), + Expires: s.sctx.cfg.Session.GetExpiryTime(), + Format: constants.CertificateFormatStandard, + RouteToCluster: s.site.GetName(), + Usage: clientproto.UserCertsRequest_Database, + RouteToDatabase: routeToDatabase, + } + + _, certs, err := client.PerformSessionMFACeremony(s.ctx, client.PerformSessionMFACeremonyParams{ + CurrentAuthClient: s.clt, + RootAuthClient: s.sctx.cfg.RootClient, + MFACeremony: newMFACeremony(s.stream.WSStream, s.sctx.cfg.RootClient.CreateAuthenticateChallenge), + MFAAgainstRoot: s.sctx.cfg.RootClusterName == s.site.GetName(), + MFARequiredReq: &clientproto.IsMFARequiredRequest{ + Target: &clientproto.IsMFARequiredRequest_Database{Database: &routeToDatabase}, + }, + CertsReq: &certsReq, + }) + if err != nil && !errors.Is(err, services.ErrSessionMFANotRequired) { + return nil, nil, trace.Wrap(err, "failed performing mfa ceremony") + } + + if certs == nil { + certs, err = s.sctx.cfg.RootClient.GenerateUserCerts(s.ctx, certsReq) + if err != nil { + return nil, nil, trace.Wrap(err, "failed issuing user certs") + } + } + + tlsCert, err := pk.TLSCertificate(certs.TLS) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return &tlsCert, &routeToDatabase, nil +} + +func (s *databaseInteractiveSession) sendSessionMetadata() error { + sessionMetadataResponse, err := json.Marshal(siteSessionGenerateResponse{Session: session.Session{ + // TODO(gabrielcorado): Have a consistent Session ID. Right now, the + // initial session ID returned won't be correct as the session is only + // initialized by the database server after the REPL starts. + ClusterName: s.site.GetName(), + }}) + if err != nil { + return trace.Wrap(err) + } + + envelope := &terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketSessionMetadata, + Payload: string(sessionMetadataResponse), + } + + envelopeBytes, err := gogoproto.Marshal(envelope) + if err != nil { + return trace.Wrap(err) + } + + err = s.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + // fetchDatabaseWithName fetch a database with provided database name. func fetchDatabaseWithName(ctx context.Context, clt resourcesAPIGetter, r *http.Request, databaseName string) (types.Database, error) { resp, err := clt.ListResources(ctx, proto.ListResourcesRequest{ diff --git a/lib/web/databases_test.go b/lib/web/databases_test.go index 3a0dd78d212d0..4a25b13557d68 100644 --- a/lib/web/databases_test.go +++ b/lib/web/databases_test.go @@ -20,25 +20,35 @@ package web import ( "context" + "crypto/tls" "encoding/json" "fmt" "net/http" "net/url" "regexp" + "sync" "testing" "time" + "github.com/gogo/protobuf/proto" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/stretchr/testify/require" + authproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" + "github.com/gravitational/teleport/lib/client" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" awslib "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/web/terminal" "github.com/gravitational/teleport/lib/web/ui" ) @@ -524,6 +534,216 @@ func TestHandleSQLServerConfigureScriptDatabaseURIEscaped(t *testing.T) { } } +func TestConnectDatabaseInteractiveSession(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + databaseProtocol := defaults.ProtocolPostgres + + // Use a mock REPL and modify it adding the additional configuration when + // it is set. + repl := &mockDatabaseREPL{message: "hello from repl"} + + s := newWebSuiteWithConfig(t, webSuiteConfig{ + disableDiskBasedRecording: true, + authPreferenceSpec: &types.AuthPreferenceSpecV2{ + Type: constants.Local, + ConnectorName: constants.PasswordlessConnector, + SecondFactor: constants.SecondFactorOn, + RequireMFAType: types.RequireMFAType_SESSION, + Webauthn: &types.Webauthn{ + RPID: "localhost", + }, + }, + databaseREPLGetter: &mockDatabaseREPLRegistry{ + repl: map[string]dbrepl.REPLNewFunc{ + databaseProtocol: func(ctx context.Context, c *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, error) { + repl.setConfig(c) + return repl, nil + }, + }, + }, + }) + s.webHandler.handler.cfg.PublicProxyAddr = s.webHandler.handler.cfg.ProxyWebAddr.String() + + accessRole, err := types.NewRole("access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + DatabaseNames: []string{types.Wildcard}, + DatabaseUsers: []string{types.Wildcard}, + }, + }) + require.NoError(t, err) + pack := s.authPackWithMFA(t, "user", accessRole) + + databaseName := "db" + selfHosted, err := types.NewDatabaseV3(types.Metadata{ + Name: databaseName, + }, types.DatabaseSpecV3{ + Protocol: databaseProtocol, + URI: "localhost:12345", + }) + require.NoError(t, err) + + _, err = s.server.Auth().UpsertDatabaseServer(ctx, mustCreateDatabaseServer(t, selfHosted)) + require.NoError(t, err) + + u := url.URL{ + Host: s.webServer.Listener.Addr().String(), + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()), + } + + header := http.Header{} + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + ws, resp, err := dialer.DialContext(ctx, u.String(), header) + require.NoError(t, err) + defer ws.Close() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token)) + + req := DatabaseSessionRequest{ + Protocol: databaseProtocol, + ServiceName: databaseName, + DatabaseName: "postgres", + DatabaseUser: "postgres", + DatabaseRoles: []string{"reader"}, + } + encodedReq, err := json.Marshal(req) + require.NoError(t, err) + reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketDatabaseSessionRequest, + Payload: string(encodedReq), + }) + require.NoError(t, err) + require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage)) + + performMFACeremonyWS(t, ws, pack) + + // After the MFA is performed we expect the WebSocket to receive the + // session data information. + sessionData := receiveWSMessage(t, ws) + require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type) + + // Assert data written by the REPL comes as raw data. + replResp := receiveWSMessage(t, ws) + require.Equal(t, defaults.WebsocketRaw, replResp.Type) + require.Equal(t, repl.message, replResp.Payload) + + require.NoError(t, ws.Close()) + require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed") +} + +func receiveWSMessage(t *testing.T, ws *websocket.Conn) terminal.Envelope { + t.Helper() + + typ, raw, err := ws.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.BinaryMessage, typ) + var env terminal.Envelope + require.NoError(t, proto.Unmarshal(raw, &env)) + return env +} + +func performMFACeremonyWS(t *testing.T, ws *websocket.Conn, pack *authPack) { + t.Helper() + + ty, raw, err := ws.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.BinaryMessage, ty, "got unexpected websocket message type %d", ty) + + var env terminal.Envelope + require.NoError(t, proto.Unmarshal(raw, &env)) + + var challenge client.MFAAuthenticateChallenge + require.NoError(t, json.Unmarshal([]byte(env.Payload), &challenge)) + + res, err := pack.device.SolveAuthn(&authproto.MFAAuthenticateChallenge{ + WebauthnChallenge: wantypes.CredentialAssertionToProto(challenge.WebauthnChallenge), + }) + require.NoError(t, err) + + webauthnResBytes, err := json.Marshal(wantypes.CredentialAssertionResponseFromProto(res.GetWebauthn())) + require.NoError(t, err) + + envelopeBytes, err := proto.Marshal(&terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketMFAChallenge, + Payload: string(webauthnResBytes), + }) + require.NoError(t, err) + require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) +} + +type mockDatabaseREPLRegistry struct { + repl map[string]dbrepl.REPLNewFunc +} + +// NewInstance implements repl.REPLGetter. +func (m *mockDatabaseREPLRegistry) NewInstance(ctx context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, error) { + if replFunc, ok := m.repl[cfg.Route.Protocol]; ok { + return replFunc(ctx, cfg) + } + + return nil, trace.NotImplemented("not supported") +} + +// IsSupported implements repl.REPLGetter. +func (m *mockDatabaseREPLRegistry) IsSupported(protocol string) bool { + _, supported := m.repl[protocol] + return supported +} + +type mockDatabaseREPL struct { + mu sync.Mutex + message string + cfg *dbrepl.NewREPLConfig + closed bool +} + +func (m *mockDatabaseREPL) Run(_ context.Context) error { + m.mu.Lock() + defer func() { + m.closeUnlocked() + m.mu.Unlock() + }() + + if _, err := m.cfg.Client.Write([]byte(m.message)); err != nil { + return trace.Wrap(err) + } + + if _, err := m.cfg.ServerConn.Write([]byte("Hello")); err != nil { + return trace.Wrap(err) + } + + return nil +} + +func (m *mockDatabaseREPL) setConfig(c *dbrepl.NewREPLConfig) { + m.mu.Lock() + defer m.mu.Unlock() + m.cfg = c +} + +func (m *mockDatabaseREPL) getClosed() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.closed +} + +func (m *mockDatabaseREPL) closeUnlocked() { + m.closed = true +} + func mustCreateDatabaseServer(t *testing.T, db *types.DatabaseV3) types.DatabaseServer { t.Helper() diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index a869d7a113f91..0597faa8e5425 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -87,9 +87,14 @@ func (h *Handler) awsOIDCListDatabases(w http.ResponseWriter, r *http.Request, p return nil, trace.Wrap(err) } + accessChecker, err := sctx.GetUserAccessChecker() + if err != nil { + return nil, trace.Wrap(err) + } + return ui.AWSOIDCListDatabasesResponse{ NextToken: listDatabasesResp.NextToken, - Databases: ui.MakeDatabases(listDatabasesResp.Databases, nil, nil), + Databases: ui.MakeDatabases(listDatabasesResp.Databases, accessChecker, h.cfg.DatabaseREPLRegistry), }, nil } diff --git a/lib/web/kube.go b/lib/web/kube.go index aad3a0a25c817..ccfd76380103f 100644 --- a/lib/web/kube.go +++ b/lib/web/kube.go @@ -23,6 +23,7 @@ import ( "context" "encoding/json" "errors" + "log/slog" "net/http" "strings" "sync/atomic" @@ -62,6 +63,7 @@ type podHandler struct { ws *websocket.Conn keepAliveInterval time.Duration log *logrus.Entry + logger *slog.Logger userClient authclient.ClientI localCA types.CertAuthority @@ -207,7 +209,7 @@ func (p *podHandler) handler(r *http.Request) error { }) // Start sending ping frames through websocket to the client. - go startWSPingLoop(r.Context(), p.ws, p.keepAliveInterval, p.log, p.Close) + go startWSPingLoop(r.Context(), p.ws, p.keepAliveInterval, p.logger, p.Close) pk, err := keys.ParsePrivateKey(p.sctx.cfg.Session.GetTLSPriv()) if err != nil { diff --git a/lib/web/servers.go b/lib/web/servers.go index af76b4d924620..8babf48dbb1bf 100644 --- a/lib/web/servers.go +++ b/lib/web/servers.go @@ -29,7 +29,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/ui" webui "github.com/gravitational/teleport/lib/web/ui" ) @@ -125,13 +124,8 @@ func (h *Handler) clusterDatabasesGet(w http.ResponseWriter, r *http.Request, p return nil, trace.Wrap(err) } - dbNames, dbUsers, err := getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - return listResourcesGetResponse{ - Items: webui.MakeDatabases(databases, dbUsers, dbNames), + Items: webui.MakeDatabases(databases, accessChecker, h.cfg.DatabaseREPLRegistry), StartKey: page.NextKey, TotalCount: page.Total, }, nil @@ -159,12 +153,7 @@ func (h *Handler) clusterDatabaseGet(w http.ResponseWriter, r *http.Request, p h return nil, trace.Wrap(err) } - dbNames, dbUsers, err := getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - - return webui.MakeDatabase(database, dbUsers, dbNames, false /* requiresRequest */), nil + return webui.MakeDatabase(database, accessChecker, h.cfg.DatabaseREPLRegistry, false /* requiresRequest */), nil } // clusterDatabaseServicesList returns a list of DatabaseServices (database agents) in a form the UI can present. @@ -333,25 +322,6 @@ func (h *Handler) desktopIsActive(w http.ResponseWriter, r *http.Request, p http return desktopIsActive{false}, nil } -func getDatabaseUsersAndNames(accessChecker services.AccessChecker) (dbNames []string, dbUsers []string, err error) { - dbNames, dbUsers, err = accessChecker.CheckDatabaseNamesAndUsers(0, true /* force ttl override*/) - if err != nil { - // if NotFound error: - // This user cannot request database access, has no assigned database names or users - // - // Every other error should be reported upstream. - if !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) - } - - // We proceed with an empty list of DBUsers and DBNames - dbUsers = []string{} - dbNames = []string{} - } - - return dbNames, dbUsers, nil -} - type desktopIsActive struct { Active bool `json:"active"` } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 072e51ecf1140..9326140447eac 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/http" "net/url" @@ -126,6 +127,10 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl teleport.ComponentKey: teleport.ComponentWebsocket, "session_id": cfg.SessionData.ID.String(), }), + logger: cfg.Logger.With( + teleport.ComponentKey, teleport.ComponentWebsocket, + "session_id", cfg.SessionData.ID.String(), + ), ctx: cfg.SessionCtx, userAuthClient: cfg.UserAuthClient, localAccessPoint: cfg.LocalAccessPoint, @@ -152,6 +157,8 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl // TerminalHandlerConfig contains the configuration options necessary to // correctly set up the TerminalHandler type TerminalHandlerConfig struct { + // Logger specifies the logger. + Logger *slog.Logger // Term is the initial PTY size. Term session.TerminalParams // SessionCtx is the context for the users web session. @@ -205,6 +212,10 @@ type TerminalHandlerConfig struct { } func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { + if t.Logger == nil { + t.Logger = slog.Default().With(teleport.ComponentKey, teleport.ComponentWebsocket) + } + // Make sure whatever session is requested is a valid session id. if !t.SessionData.ID.IsZero() { _, err := session.ParseID(t.SessionData.ID.String()) @@ -259,6 +270,8 @@ func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { type sshBaseHandler struct { // log holds the structured logger. log *logrus.Entry + // logger holds the structured logger. + logger *slog.Logger // ctx is a web session context for the currently logged-in user. ctx *SessionContext // userAuthClient is used to fetch nodes and sessions from the backend via the users' identity. @@ -470,7 +483,7 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { }) // Start sending ping frames through websocket to client. - go startWSPingLoop(ctx, ws, t.keepAliveInterval, t.log, t.Close) + go startWSPingLoop(ctx, ws, t.keepAliveInterval, t.logger, t.Close) // Pump raw terminal in/out and audit events into the websocket. go t.streamEvents(ctx, tc) diff --git a/lib/web/ui/server.go b/lib/web/ui/server.go index 9921307c48134..93c3781aa6172 100644 --- a/lib/web/ui/server.go +++ b/lib/web/ui/server.go @@ -306,6 +306,9 @@ type Database struct { AWS *AWS `json:"aws,omitempty"` // RequireRequest indicates if a returned resource is only accessible after an access request RequiresRequest bool `json:"requiresRequest,omitempty"` + // SupportsInteractive is a flag to indicate the database supports + // interactive sessions using database REPLs. + SupportsInteractive bool `json:"supports_interactive,omitempty"` } // AWS contains AWS specific fields. @@ -322,22 +325,35 @@ const ( LabelStatus = "status" ) +// DatabaseInteractiveChecker is used to check if the database supports +// interactive sessions using database REPLs. +type DatabaseInteractiveChecker interface { + IsSupported(protocol string) bool +} + // MakeDatabase creates database objects. -func MakeDatabase(database types.Database, dbUsers, dbNames []string, requiresRequest bool) Database { +func MakeDatabase(database types.Database, accessChecker services.AccessChecker, interactiveChecker DatabaseInteractiveChecker, requiresRequest bool) Database { + dbNames := accessChecker.EnumerateDatabaseNames(database) + var dbUsers []string + if res, err := accessChecker.EnumerateDatabaseUsers(database); err == nil { + dbUsers = res.Allowed() + } + uiLabels := ui.MakeLabelsWithoutInternalPrefixes(database.GetAllLabels()) db := Database{ - Kind: database.GetKind(), - Name: database.GetName(), - Desc: database.GetDescription(), - Protocol: database.GetProtocol(), - Type: database.GetType(), - Labels: uiLabels, - DatabaseUsers: dbUsers, - DatabaseNames: dbNames, - Hostname: stripProtocolAndPort(database.GetURI()), - URI: database.GetURI(), - RequiresRequest: requiresRequest, + Kind: database.GetKind(), + Name: database.GetName(), + Desc: database.GetDescription(), + Protocol: database.GetProtocol(), + Type: database.GetType(), + Labels: uiLabels, + DatabaseUsers: dbUsers, + DatabaseNames: dbNames.Allowed(), + Hostname: stripProtocolAndPort(database.GetURI()), + URI: database.GetURI(), + RequiresRequest: requiresRequest, + SupportsInteractive: interactiveChecker.IsSupported(database.GetProtocol()), } if database.IsAWSHosted() { @@ -355,10 +371,10 @@ func MakeDatabase(database types.Database, dbUsers, dbNames []string, requiresRe } // MakeDatabases creates database objects. -func MakeDatabases(databases []*types.DatabaseV3, dbUsers, dbNames []string) []Database { +func MakeDatabases(databases []*types.DatabaseV3, accessChecker services.AccessChecker, interactiveChecker DatabaseInteractiveChecker) []Database { uiServers := make([]Database, 0, len(databases)) for _, database := range databases { - db := MakeDatabase(database, dbUsers, dbNames, false /* requiresRequest */) + db := MakeDatabase(database, accessChecker, interactiveChecker, false /* requiresRequest */) uiServers = append(uiServers, db) } diff --git a/lib/web/ui/server_test.go b/lib/web/ui/server_test.go index 9af621cf09557..514e9474b0c80 100644 --- a/lib/web/ui/server_test.go +++ b/lib/web/ui/server_test.go @@ -431,7 +431,8 @@ func TestMakeDatabaseHiddenLabels(t *testing.T) { }, } - outputDb := MakeDatabase(inputDb, nil, nil, false) + accessChecker := services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, "clusterName", nil) + outputDb := MakeDatabase(inputDb, accessChecker, &mockDatabaseInteractiveChecker{}, false) require.Equal(t, []ui.Label{ { @@ -590,3 +591,33 @@ func TestSortedLabels(t *testing.T) { }) } } + +func TestMakeDatabaseSupportsInteractive(t *testing.T) { + db := &types.DatabaseV3{} + accessChecker := services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, "clusterName", nil) + + for name, tc := range map[string]struct { + supports bool + }{ + "supported": {supports: true}, + "unsupported": {supports: false}, + } { + t.Run(name, func(t *testing.T) { + interactiveChecker := &mockDatabaseInteractiveChecker{supports: tc.supports} + single := MakeDatabase(db, accessChecker, interactiveChecker, false) + require.Equal(t, tc.supports, single.SupportsInteractive) + + multi := MakeDatabases([]*types.DatabaseV3{db}, accessChecker, interactiveChecker) + require.Len(t, multi, 1) + require.Equal(t, tc.supports, multi[0].SupportsInteractive) + }) + } +} + +type mockDatabaseInteractiveChecker struct { + supports bool +} + +func (m *mockDatabaseInteractiveChecker) IsSupported(_ string) bool { + return m.supports +} diff --git a/lib/web/ws_io.go b/lib/web/ws_io.go index 2ed1e4c548807..63055a376706d 100644 --- a/lib/web/ws_io.go +++ b/lib/web/ws_io.go @@ -20,11 +20,11 @@ package web import ( "context" + "log/slog" "time" "github.com/gorilla/websocket" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" ) type WebsocketIO struct { @@ -71,8 +71,8 @@ type wsPinger interface { // to prevent the connection between web client and teleport proxy from becoming idle. // Interval is determined by the keep_alive_interval config set by user (or default). // Loop will terminate when there is an error sending ping frame or when the context is canceled. -func startWSPingLoop(ctx context.Context, pinger wsPinger, keepAliveInterval time.Duration, log logrus.FieldLogger, onClose func() error) { - log.Debugf("Starting websocket ping loop with interval %v.", keepAliveInterval) +func startWSPingLoop(ctx context.Context, pinger wsPinger, keepAliveInterval time.Duration, log *slog.Logger, onClose func() error) { + log.DebugContext(ctx, "Starting websocket ping loop with interval", "interval", keepAliveInterval) tickerCh := time.NewTicker(keepAliveInterval) defer tickerCh.Stop() @@ -83,16 +83,16 @@ func startWSPingLoop(ctx context.Context, pinger wsPinger, keepAliveInterval tim // If this is just a temporary issue, we will retry shortly anyway. deadline := time.Now().Add(time.Second) if err := pinger.WriteControl(websocket.PingMessage, nil, deadline); err != nil { - log.WithError(err).Error("Unable to send ping frame to web client") + log.ErrorContext(ctx, "Unable to send ping frame to web client", "error", err) if onClose != nil { if err := onClose(); err != nil { - log.WithError(err).Error("OnClose handler failed") + log.ErrorContext(ctx, "OnClose handler failed", "error", err) } } return } case <-ctx.Done(): - log.Debug("Terminating websocket ping loop.") + log.DebugContext(ctx, "Terminating websocket ping loop.") return } }