From 4ff0f546b618288c76514102b637751e7aed0c91 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Mon, 16 Dec 2024 15:25:33 -0300 Subject: [PATCH] [v17] PostgreSQL access through WebUI (#50287) * Add PostgreSQL REPL implementation (#49598) * feat(repl): add postgres * refactor(repl): change repl to use a single Run function * test(repl): reduce usage of require.Eventually blocks * refactor(repl): code review suggestions * refactor(repl): code review suggestions * test(repl): increase timeout values * fix(repl): commands formatting * refactor(repl): send close pgconn using a different context * fix(repl): add proper spacing between multi queries * test(repl): add fuzz test for processing commands * Add WebSocket handler for WebUI database sessions (#49749) * feat(web): add websocket handler for database webui sessions * refactor: move common structs into a separate package * refactor(web): use ALPN local proxy to dial databases * feat(repl): add default registry * refactor(web): code review suggestions * refactor: update repl config parameters * refactor: move default getter implementation * feat(web): add supports_interactive field on dbs * refactor: code review suggestions * refactor: update database REPL interfaces * chore(web): remove debug print * feat: register postgres repl * refactor(web): update MakeDatabase to receive access checker and interactive * chore(web): remove unused function * Database access through WebUI (#49979) * feat(web): add database terminal access * chore(web): make explict type cast * refactor(web): code review suggestions * chore(web): fix lint errors * refactor(web): lint errors * refactor: code review suggestions * refactor(web): filter wildcard options from connect dialog * chore(web): lint * refactor(web): code review suggestions --- lib/client/alpn.go | 27 +- lib/client/db/postgres/repl/commands.go | 123 +++++ lib/client/db/postgres/repl/commands_test.go | 185 +++++++ lib/client/db/postgres/repl/repl.go | 303 +++++++++++ lib/client/db/postgres/repl/repl_test.go | 503 ++++++++++++++++++ .../repl/testdata/TestStart/data_type.golden | 4 + .../repl/testdata/TestStart/err.golden | 1 + .../repl/testdata/TestStart/multi.golden | 5 + .../repl/testdata/TestStart/multiquery.golden | 10 + .../repl/testdata/TestStart/single.golden | 4 + lib/client/db/postgres/repl/testdata/query.go | 111 ++++ lib/client/db/repl/repl.go | 80 +++ lib/defaults/defaults.go | 4 + lib/service/service.go | 9 + lib/service/servicecfg/config.go | 5 + lib/web/apiserver.go | 19 +- lib/web/apiserver_test.go | 10 +- lib/web/databases.go | 335 +++++++++++- lib/web/databases_test.go | 220 ++++++++ lib/web/integrations_awsoidc.go | 7 +- lib/web/kube.go | 4 +- lib/web/servers.go | 34 +- lib/web/terminal.go | 15 +- lib/web/ui/server.go | 44 +- lib/web/ui/server_test.go | 33 +- lib/web/ws_io.go | 12 +- web/packages/teleport/src/Console/Console.tsx | 7 +- .../src/Console/DocumentDb/ConnectDialog.tsx | 219 ++++++++ .../Console/DocumentDb/DocumentDb.story.tsx | 202 +++++++ .../Console/DocumentDb/DocumentDb.test.tsx | 137 +++++ .../src/Console/DocumentDb/DocumentDb.tsx | 78 +++ .../teleport/src/Console/DocumentDb/index.ts | 19 + .../src/Console/DocumentDb/useDbSession.tsx | 133 +++++ .../Console/DocumentSsh/Terminal/Terminal.tsx | 5 + .../teleport/src/Console/consoleContext.tsx | 36 +- .../teleport/src/Console/stores/storeDocs.ts | 4 + .../teleport/src/Console/stores/types.ts | 9 +- .../teleport/src/Console/useTabRouting.ts | 11 +- .../ConnectDialog/ConnectDialog.story.tsx | 12 + .../Databases/ConnectDialog/ConnectDialog.tsx | 26 +- .../UnifiedResources/ResourceActionButton.tsx | 3 +- web/packages/teleport/src/config.ts | 16 + .../teleport/src/lib/term/protobuf.ts | 6 + web/packages/teleport/src/lib/term/tty.ts | 14 + .../src/services/databases/databases.test.ts | 12 + .../src/services/databases/makeDatabase.ts | 2 + .../teleport/src/services/databases/types.ts | 2 + .../teleport/src/services/session/types.ts | 20 +- 48 files changed, 2982 insertions(+), 98 deletions(-) create mode 100644 lib/client/db/postgres/repl/commands.go create mode 100644 lib/client/db/postgres/repl/commands_test.go create mode 100644 lib/client/db/postgres/repl/repl.go create mode 100644 lib/client/db/postgres/repl/repl_test.go create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/data_type.golden create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/err.golden create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/multi.golden create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden create mode 100644 lib/client/db/postgres/repl/testdata/TestStart/single.golden create mode 100644 lib/client/db/postgres/repl/testdata/query.go create mode 100644 lib/client/db/repl/repl.go create mode 100644 web/packages/teleport/src/Console/DocumentDb/ConnectDialog.tsx create mode 100644 web/packages/teleport/src/Console/DocumentDb/DocumentDb.story.tsx create mode 100644 web/packages/teleport/src/Console/DocumentDb/DocumentDb.test.tsx create mode 100644 web/packages/teleport/src/Console/DocumentDb/DocumentDb.tsx create mode 100644 web/packages/teleport/src/Console/DocumentDb/index.ts create mode 100644 web/packages/teleport/src/Console/DocumentDb/useDbSession.tsx diff --git a/lib/client/alpn.go b/lib/client/alpn.go index ec4e920923da7..db58eb374701f 100644 --- a/lib/client/alpn.go +++ b/lib/client/alpn.go @@ -85,14 +85,33 @@ type ALPNAuthTunnelConfig struct { // RouteToDatabase contains the destination server that must receive the connection. // Specific for database proxying. RouteToDatabase proto.RouteToDatabase + + // TLSCert specifies the TLS certificate used on the proxy connection. + TLSCert *tls.Certificate +} + +func (c *ALPNAuthTunnelConfig) CheckAndSetDefaults(ctx context.Context) error { + if c.AuthClient == nil { + return trace.BadParameter("missing auth client") + } + + if c.TLSCert == nil { + tlsCert, err := getUserCerts(ctx, c.AuthClient, c.MFAResponse, c.Expires, c.RouteToDatabase, c.ConnectionDiagnosticID) + if err != nil { + return trace.BadParameter("failed to parse private key: %v", err) + } + + c.TLSCert = &tlsCert + } + + return nil } // RunALPNAuthTunnel runs a local authenticated ALPN proxy to another service. // At least one Route (which defines the service) must be defined func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error { - tlsCert, err := getUserCerts(ctx, cfg.AuthClient, cfg.MFAResponse, cfg.Expires, cfg.RouteToDatabase, cfg.ConnectionDiagnosticID) - if err != nil { - return trace.BadParameter("failed to parse private key: %v", err) + if err := cfg.CheckAndSetDefaults(ctx); err != nil { + return trace.Wrap(err) } lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{ @@ -101,7 +120,7 @@ func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error { Protocols: []alpn.Protocol{cfg.Protocol}, Listener: cfg.Listener, ParentContext: ctx, - Cert: tlsCert, + Cert: *cfg.TLSCert, }, alpnproxy.WithALPNConnUpgradeTest(ctx, getClusterCACertPool(cfg.AuthClient))) if err != nil { return trace.Wrap(err) diff --git a/lib/client/db/postgres/repl/commands.go b/lib/client/db/postgres/repl/commands.go new file mode 100644 index 0000000000000..07d2faf7a02aa --- /dev/null +++ b/lib/client/db/postgres/repl/commands.go @@ -0,0 +1,123 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "fmt" + "strings" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/asciitable" +) + +// processCommand receives a command call and return the reply and if the +// command terminates the session. +func (r *REPL) processCommand(line string) (string, bool) { + cmdStr, args, _ := strings.Cut(strings.TrimPrefix(line, commandPrefix), " ") + cmd, ok := r.commands[cmdStr] + if !ok { + return "Unknown command. Try \\? to show the list of supported commands." + lineBreak, false + } + + return cmd.ExecFunc(r, args) +} + +// commandType specify the command category. This is used to organize the +// commands, for example, when showing them in the help command. +type commandType string + +const ( + // commandTypeGeneral represents a general-purpose command type. + commandTypeGeneral commandType = "General" + // commandTypeConnection represents a command type related to connection + // operations. + commandTypeConnection = "Connection" +) + +// command represents a command that can be executed in the REPL. +type command struct { + // Type specifies the type of the command. + Type commandType + // Description provides a user-friendly explanation of what the command + // does. + Description string + // ExecFunc is the function to execute the command. The commands can either + // return a reply (that will be sent back to the client) as a string. It can + // terminate the REPL by returning bool on the second argument. + ExecFunc func(r *REPL, args string) (reply string, exit bool) +} + +func initCommands() map[string]*command { + return map[string]*command{ + "q": { + Type: commandTypeGeneral, + Description: "Terminates the session.", + ExecFunc: func(_ *REPL, _ string) (string, bool) { return "", true }, + }, + "teleport": { + Type: commandTypeGeneral, + Description: "Show Teleport interactive shell information, such as execution limitations.", + ExecFunc: func(_ *REPL, _ string) (string, bool) { + // Formats limitiations in a dash list. Example: + // - hello + // multi line + // - another item + var limitations strings.Builder + for _, l := range descriptiveLimitations { + limitations.WriteString("- " + strings.Join(strings.Split(l, "\n"), "\n ") + lineBreak) + } + + return fmt.Sprintf( + "Teleport PostgreSQL interactive shell (v%s)\n\nLimitations: \n%s", + teleport.Version, + limitations.String(), + ), false + }, + }, + "?": { + Type: commandTypeGeneral, + Description: "Show the list of supported commands.", + ExecFunc: func(r *REPL, _ string) (string, bool) { + typesTable := make(map[commandType]*asciitable.Table) + for cmdStr, cmd := range r.commands { + if _, ok := typesTable[cmd.Type]; !ok { + table := asciitable.MakeHeadlessTable(2) + typesTable[cmd.Type] = &table + } + + typesTable[cmd.Type].AddRow([]string{"\\" + cmdStr, cmd.Description}) + } + + var res strings.Builder + for cmdType, output := range typesTable { + res.WriteString(string(cmdType) + lineBreak) + output.AsBuffer().WriteTo(&res) + res.WriteString(lineBreak) + } + + return res.String(), false + }, + }, + "session": { + Type: commandTypeConnection, + Description: "Display information about the current session, like user, and database instance.", + ExecFunc: func(r *REPL, _ string) (string, bool) { + return fmt.Sprintf("Connected to %q instance at %q database as %q user.", r.route.ServiceName, r.route.Database, r.route.Username), false + }, + }, + } +} diff --git a/lib/client/db/postgres/repl/commands_test.go b/lib/client/db/postgres/repl/commands_test.go new file mode 100644 index 0000000000000..2a974d470601f --- /dev/null +++ b/lib/client/db/postgres/repl/commands_test.go @@ -0,0 +1,185 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + clientproto "github.com/gravitational/teleport/api/client/proto" +) + +func TestCommandExecution(t *testing.T) { + ctx := context.Background() + + for name, tt := range map[string]struct { + line string + commandResult string + expectedArgs string + expectUnknown bool + commandExit bool + }{ + "execute": {line: "\\test", commandResult: "test"}, + "execute with additional arguments": {line: "\\test a b", commandResult: "test", expectedArgs: "a b"}, + "execute with exit": {line: "\\test", commandExit: true}, + "execute with leading and trailing whitespace": {line: " \\test ", commandResult: "test"}, + "unknown command with semicolon": {line: "\\test;", expectUnknown: true}, + "unknown command": {line: "\\wrong", expectUnknown: true}, + "with special characters": {line: "\\special_chars_!@#$%^&*()}", expectUnknown: true}, + "empty command": {line: "\\", expectUnknown: true}, + } { + t.Run(name, func(t *testing.T) { + commandArgsChan := make(chan string, 1) + instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + runErrChan := make(chan error) + go func() { + runErrChan <- instance.Run(ctx) + }() + + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + // Reset available commands and add a test command so we can assert + // the command execution flow without relying in commands + // implementation or test server capabilities. + instance.commands = map[string]*command{ + "test": { + ExecFunc: func(r *REPL, args string) (string, bool) { + commandArgsChan <- args + return tt.commandResult, tt.commandExit + }, + }, + } + + writeLine(t, tc, tt.line) + if tt.expectUnknown { + reply := readUntilNextLead(t, tc) + require.True(t, strings.HasPrefix(strings.ToLower(reply), "unknown command")) + return + } + + select { + case args := <-commandArgsChan: + require.Equal(t, tt.expectedArgs, args) + case <-time.After(time.Second): + require.Fail(t, "expected to command args from test server but got nothing") + } + + // When the command exits, the REPL and the connections will be + // closed. + if tt.commandExit { + require.EventuallyWithT(t, func(t *assert.CollectT) { + var buf []byte + _, err := tc.conn.Read(buf[0:]) + assert.ErrorIs(t, err, io.EOF) + }, 5*time.Second, time.Millisecond) + + select { + case err := <-runErrChan: + require.NoError(t, err, "expected the REPL instance exit gracefully") + case <-time.After(5 * time.Second): + require.Fail(t, "expected REPL run to terminate but got nothing") + } + return + } + + reply := readUntilNextLead(t, tc) + require.Equal(t, tt.commandResult, reply) + + // Terminate the REPL run session and wait for the Run results. + cancel() + select { + case err := <-runErrChan: + require.ErrorIs(t, err, context.Canceled, "expected the REPL instance to finish running with error due to cancelation") + case <-time.After(5 * time.Second): + require.Fail(t, "expected REPL run to terminate but got nothing") + } + }) + } +} + +func TestCommands(t *testing.T) { + availableCmds := initCommands() + for cmdName, tc := range map[string]struct { + repl *REPL + args string + expectExit bool + assertCommandReply require.ValueAssertionFunc + }{ + "q": {expectExit: true}, + "teleport": { + assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { + require.Contains(t, val, teleport.Version, "expected \\teleport command to include current Teleport version") + }, + }, + "?": { + repl: &REPL{commands: availableCmds}, + assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { + for cmd := range availableCmds { + require.Contains(t, val, cmd, "expected \\? command to include information about \\%s", cmd) + } + }, + }, + "session": { + repl: &REPL{route: clientproto.RouteToDatabase{ + ServiceName: "service", + Username: "username", + Database: "database", + }}, + assertCommandReply: func(t require.TestingT, val interface{}, _ ...interface{}) { + require.Contains(t, val, "service", "expected \\session command to contain service name") + require.Contains(t, val, "username", "expected \\session command to contain username") + require.Contains(t, val, "database", "expected \\session command to contain database name") + }, + }, + } { + t.Run(cmdName, func(t *testing.T) { + cmd, ok := availableCmds[cmdName] + require.True(t, ok, "expected command %q to be available at commands", cmdName) + reply, exit := cmd.ExecFunc(tc.repl, tc.args) + if tc.expectExit { + require.True(t, exit, "expected command to exit the REPL") + return + } + tc.assertCommandReply(t, reply) + }) + } +} + +func FuzzCommands(f *testing.F) { + f.Add("q") + f.Add("?") + f.Add("session") + f.Add("teleport") + + repl := &REPL{commands: make(map[string]*command)} + f.Fuzz(func(t *testing.T, line string) { + require.NotPanics(t, func() { + _, _ = repl.processCommand(line) + }) + }) +} diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go new file mode 100644 index 0000000000000..514d9160e3efb --- /dev/null +++ b/lib/client/db/postgres/repl/repl.go @@ -0,0 +1,303 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "strings" + "time" + + "github.com/gravitational/trace" + "github.com/jackc/pgconn" + "golang.org/x/term" + + "github.com/gravitational/teleport" + clientproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/lib/asciitable" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" + "github.com/gravitational/teleport/lib/defaults" +) + +type REPL struct { + connConfig *pgconn.Config + client io.ReadWriteCloser + serverConn net.Conn + route clientproto.RouteToDatabase + term *term.Terminal + commands map[string]*command +} + +func New(_ context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, error) { + config, err := pgconn.ParseConfig(fmt.Sprintf("postgres://%s", hostnamePlaceholder)) + if err != nil { + return nil, trace.Wrap(err) + } + config.User = cfg.Route.Username + config.Database = cfg.Route.Database + config.ConnectTimeout = defaults.DatabaseConnectTimeout + config.RuntimeParams = map[string]string{ + applicationNameParamName: applicationNameParamValue, + } + config.TLSConfig = nil + + // Provide a lookup function to avoid having the hostname placeholder to + // resolve into something else. Note that the returned value won't be used. + config.LookupFunc = func(_ context.Context, _ string) ([]string, error) { + return []string{hostnamePlaceholder}, nil + } + config.DialFunc = func(_ context.Context, _, _ string) (net.Conn, error) { + return cfg.ServerConn, nil + } + + return &REPL{ + connConfig: config, + client: cfg.Client, + serverConn: cfg.ServerConn, + route: cfg.Route, + term: term.NewTerminal(cfg.Client, ""), + commands: initCommands(), + }, nil +} + +// Run starts and run the PostgreSQL REPL session. The provided context is used +// to interrupt the execution and clean up resources. +func (r *REPL) Run(ctx context.Context) error { + pgConn, err := pgconn.ConnectConfig(ctx, r.connConfig) + if err != nil { + return trace.Wrap(err) + } + defer func() { + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + pgConn.Close(closeCtx) + }() + + // term.Terminal blocks reads/writes without respecting the context. The + // only thing that unblocks it is closing the underlaying connection (in + // our case r.client). On this goroutine we only watch for context + // cancelation and close the connection. This will unblocks all terminal + // reads/writes. + ctxCancelCh := make(chan struct{}) + defer close(ctxCancelCh) + go func() { + select { + case <-ctx.Done(): + _ = r.client.Close() + case <-ctxCancelCh: + } + }() + + if err := r.presentBanner(); err != nil { + return trace.Wrap(err) + } + + var ( + multilineAcc strings.Builder + readingMultiline bool + ) + + lead := lineLeading(r.route) + leadSpacing := strings.Repeat(" ", len(lead)) + r.term.SetPrompt(lineBreak + lead) + + for { + line, err := r.term.ReadLine() + if err != nil { + return trace.Wrap(formatTermError(ctx, err)) + } + + // ReadLine should always return the line without trailing line breaks, + // but we still require to remove trailing and leading spaces. + line = strings.TrimSpace(line) + + var reply string + switch { + case strings.HasPrefix(line, commandPrefix) && !readingMultiline: + var exit bool + reply, exit = r.processCommand(line) + if exit { + return nil + } + case strings.HasSuffix(line, executionRequestSuffix): + var query string + if readingMultiline { + multilineAcc.WriteString(lineBreak + line) + query = multilineAcc.String() + } else { + query = line + } + + // Reset multiline state. + multilineAcc.Reset() + readingMultiline = false + r.term.SetPrompt(lineBreak + lead) + + reply = formatResult(pgConn.Exec(ctx, query).ReadAll()) + lineBreak + default: + // If there wasn't a specific execution, we assume the input is + // multi-line. In this case, we need to accumulate the contents. + + // If this isn't the first line, add the line break as the + // ReadLine function removes it. + if readingMultiline { + multilineAcc.WriteString(lineBreak) + } + + readingMultiline = true + multilineAcc.WriteString(line) + r.term.SetPrompt(leadSpacing) + } + + if reply == "" { + continue + } + + if _, err := r.term.Write([]byte(reply)); err != nil { + return trace.Wrap(formatTermError(ctx, err)) + } + } +} + +// formatTermError changes the term.Terminal error to match caller expectations. +func formatTermError(ctx context.Context, err error) error { + // When context is canceled it will immediately lead read/write errors due + // to the closed connection. For this cases we return the context error. + if ctx.Err() != nil && (errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed)) { + return ctx.Err() + } + + return err +} + +func (r *REPL) presentBanner() error { + _, err := fmt.Fprintf( + r.term, + `Teleport PostgreSQL interactive shell (v%s) +Connected to %q instance as %q user. +Type \? for help.`, + teleport.Version, + r.route.GetServiceName(), + r.route.GetUsername()) + return trace.Wrap(err) +} + +// formatResult formats a pgconn.Exec result. +func formatResult(results []*pgconn.Result, err error) string { + if err != nil { + return errorReplyPrefix + err.Error() + } + + var ( + sb strings.Builder + resultsLen = len(results) + ) + for i, res := range results { + if !res.CommandTag.Select() { + return res.CommandTag.String() + } + + // build columns + var columns []string + for _, fd := range res.FieldDescriptions { + columns = append(columns, string(fd.Name)) + } + + table := asciitable.MakeTable(columns) + for _, row := range res.Rows { + rowData := make([]string, len(columns)) + for i, data := range row { + // The PostgreSQL package is responsible for transforming the + // row data into a readable format. + rowData[i] = string(data) + } + + table.AddRow(rowData) + } + + table.AsBuffer().WriteTo(&sb) + sb.WriteString(rowsText(len(res.Rows))) + + // Add line breaks to separate results. Except the last result, which + // will have line breaks added later in the reply. + if i != resultsLen-1 { + sb.WriteString(lineBreak + lineBreak) + } + } + + return sb.String() +} + +func lineLeading(route clientproto.RouteToDatabase) string { + return fmt.Sprintf("%s=> ", route.Database) +} + +func rowsText(count int) string { + rowTxt := "row" + if count > 1 { + rowTxt = "rows" + } + + return fmt.Sprintf("(%d %s)", count, rowTxt) +} + +const ( + // hostnamePlaceholder is the hostname used when connecting to the database. + // The pgconn functions require a hostname, however, since we already have + // the connection, we just need to provide a name to suppress this + // requirement. + hostnamePlaceholder = "repl" + // lineBreak represents a line break on the REPL. + lineBreak = "\r\n" + // commandPrefix is the prefix that identifies a REPL command. + commandPrefix = "\\" + // executionRequestSuffix is the suffix that indicates the input must be + // executed. + executionRequestSuffix = ";" + // errorReplyPrefix is the prefix presented when there is a execution error. + errorReplyPrefix = "ERR " +) + +const ( + // applicationNameParamName defines the application name parameter name. + // + // https://www.postgresql.org/docs/17/libpq-connect.html#LIBPQ-CONNECT-APPLICATION-NAME + applicationNameParamName = "application_name" + // applicationNameParamValue defines the application name parameter value. + applicationNameParamValue = "teleport-repl" +) + +// descriptiveLimitations defines a user-friendly text containing the REPL +// limitations. +var descriptiveLimitations = []string{ + `Query cancellation is not supported. Once a query is sent, its execution +cannot be canceled. Note that Teleport sends a terminate message to the database +when the database session terminates. This flow doesn't guarantee that any +running queries will be canceled. +See https://www.postgresql.org/docs/17/protocol-flow.html#PROTOCOL-FLOW-TERMINATION for more details on the termination flow.`, + // This limitation is due to our terminal emulator not fully supporting this + // shortcut's custom handler. Instead, it will close the terminal, leading + // to terminating the session. To avoid having users accidentally + // terminating their sessions, we're turning this off until we have a better + // solution and propose the behavior for it. + // + // This shortcut filtered out by the WebUI key handler. + "Pressing CTRL-C will have no effect in this shell.", +} diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go new file mode 100644 index 0000000000000..0aa03b84c8023 --- /dev/null +++ b/lib/client/db/postgres/repl/repl_test.go @@ -0,0 +1,503 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "errors" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + clientproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/lib/client/db/postgres/repl/testdata" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" + "github.com/gravitational/teleport/lib/utils/golden" +) + +func TestStart(t *testing.T) { + ctx := context.Background() + _, tc := StartWithServer(t, ctx) + + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + writeLine(t, tc, singleRowQuery) + singleRowQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "single", []byte(singleRowQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "single")), singleRowQueryResult) + + writeLine(t, tc, multiRowQuery) + multiRowQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "multi", []byte(multiRowQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "multi")), multiRowQueryResult) + + writeLine(t, tc, errorQuery) + errorQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "err", []byte(errorQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "err")), errorQueryResult) + + writeLine(t, tc, dataTypesQuery) + dataTypeQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "data_type", []byte(dataTypeQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "data_type")), dataTypeQueryResult) + + writeLine(t, tc, multiQuery) + multiQueryResult := readUntilNextLead(t, tc) + if golden.ShouldSet() { + golden.SetNamed(t, "multiquery", []byte(multiQueryResult)) + } + require.Equal(t, string(golden.GetNamed(t, "multiquery")), multiQueryResult) +} + +// TestQuery given some input lines, the REPL should execute the expected +// query on the PostgreSQL test server. +func TestQuery(t *testing.T) { + ctx := context.Background() + _, tc := StartWithServer(t, ctx, WithCustomQueries()) + + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + for name, tt := range map[string]struct { + lines []string + expectedQuery string + }{ + "query": {lines: []string{"SELECT 1;"}, expectedQuery: "SELECT 1;"}, + "query multiple semicolons": {lines: []string{"SELECT 1; ;;"}, expectedQuery: "SELECT 1; ;;"}, + "query multiple semicolons with trailing space": {lines: []string{"SELECT 1; ;; "}, expectedQuery: "SELECT 1; ;;"}, + "multiline query": {lines: []string{"SELECT", "1", ";"}, expectedQuery: "SELECT\r\n1\r\n;"}, + "malformatted": {lines: []string{"SELECT err;"}, expectedQuery: "SELECT err;"}, + "query with special characters": {lines: []string{"SELECT 'special_chars_!@#$%^&*()';"}, expectedQuery: "SELECT 'special_chars_!@#$%^&*()';"}, + "leading and trailing whitespace": {lines: []string{" SELECT 1; "}, expectedQuery: "SELECT 1;"}, + "multiline with excessive whitespace": {lines: []string{" SELECT", " 1", " ;"}, expectedQuery: "SELECT\r\n1\r\n;"}, + // Commands should only be executed if they are at the beginning of the + // first line. + "with command in the middle": {lines: []string{"SELECT \\d 1;"}, expectedQuery: "SELECT \\d 1;"}, + "multiline with command in the middle": {lines: []string{"SELECT", "\\d", ";"}, expectedQuery: "SELECT\r\n\\d\r\n;"}, + "multiline with command in the last line": {lines: []string{"SELECT", "1", "\\d;"}, expectedQuery: "SELECT\r\n1\r\n\\d;"}, + } { + t.Run(name, func(t *testing.T) { + for _, line := range tt.lines { + writeLine(t, tc, line) + } + + select { + case query := <-tc.QueryChan(): + require.Equal(t, tt.expectedQuery, query) + case <-time.After(5 * time.Second): + require.Fail(t, "expected to receive query but got nothing") + } + + // Always expect a query reply from the server. + _ = readUntilNextLead(t, tc) + }) + } +} + +func TestClose(t *testing.T) { + for name, tt := range map[string]struct { + closeFunc func(tc *testCtx, cancelCtx context.CancelFunc) + expectTerminateMessage bool + }{ + "closed by context": { + closeFunc: func(_ *testCtx, cancelCtx context.CancelFunc) { + cancelCtx() + }, + expectTerminateMessage: true, + }, + "closed by server": { + closeFunc: func(tc *testCtx, _ context.CancelFunc) { + tc.CloseServer() + }, + expectTerminateMessage: false, + }, + } { + t.Run(name, func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + _, tc := StartWithServer(t, ctx) + // Consume the REPL banner. + _ = readUntilNextLead(t, tc) + + tt.closeFunc(tc, cancelFunc) + // After closing the REPL session, we expect any read/write to + // return error. In case the close wasn't effective we need to + // execute the read on a Eventually block to avoid blocking the + // test. + require.EventuallyWithT(t, func(t *assert.CollectT) { + var buf []byte + _, err := tc.conn.Read(buf[0:]) + assert.ErrorIs(t, err, io.EOF) + }, 5*time.Second, time.Millisecond) + + if !tt.expectTerminateMessage { + return + } + + select { + case <-tc.terminateChan: + case <-time.After(5 * time.Second): + require.Fail(t, "expected REPL to send terminate message but got nothing") + } + }) + } +} + +func writeLine(t *testing.T, c *testCtx, line string) { + t.Helper() + data := []byte(line + lineBreak) + + // When writing to the connection, the terminal emulator always writes back. + // If we don't consume those bytes, it will block the ReadLine call (as + // we're net.Pipe). + go func(conn net.Conn) { + buf := make([]byte, len(data)) + // We need to consume any additional replies made by the terminal + // emulator until we consume the line contents. + for { + n, err := conn.Read(buf[0:]) + if err != nil { + t.Logf("Error while terminal reply on write: %s", err) + break + } + + if string(buf[:n]) == line+lineBreak { + break + } + } + }(c.conn) + + // Given that the test connections are piped a problem with the reader side + // would lead into blocking writing. To avoid this scenario we're using + // the Eventually just to ensure a timeout on writing into the connections. + require.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := c.conn.Write(data) + assert.NoError(t, err) + }, 5*time.Second, time.Millisecond, "expected to write into the connection successfully") +} + +// readUntilNextLead reads the contents from the client connection until we +// reach the next leading prompt. +func readUntilNextLead(t *testing.T, c *testCtx) string { + t.Helper() + + var acc strings.Builder + for { + line := readLine(t, c) + if strings.HasPrefix(line, lineBreak+lineLeading(c.route)) { + break + } + + acc.WriteString(line) + } + return acc.String() +} + +func readLine(t *testing.T, c *testCtx) string { + t.Helper() + + var n int + buf := make([]byte, 1024) + // Given that the test connections are piped a problem with the writer side + // would lead into blocking reading. To avoid this scenario we're using + // the Eventually just to ensure a timeout on reading from the connections. + require.EventuallyWithT(t, func(t *assert.CollectT) { + var err error + n, err = c.conn.Read(buf[0:]) + assert.NoError(t, err) + assert.Greater(t, n, 0) + }, 5*time.Second, time.Millisecond) + return string(buf[:n]) +} + +type testCtx struct { + cfg *testCtxConfig + ctx context.Context + cancelFunc context.CancelFunc + + // conn is the connection used by tests to read/write from/to the REPL. + conn net.Conn + // clientConn is the connection passed to the REPL. + clientConn net.Conn + // serverConn is the fake database server connection (that works as a + // PostgreSQL instance). + serverConn net.Conn + // rawPgConn is the underlaying net.Conn used by pgconn client. + rawPgConn net.Conn + + route clientproto.RouteToDatabase + pgClient *pgproto3.Backend + errChan chan error + terminateChan chan struct{} + // queryChan handling custom queries is enabled the queries received by the + // test server will be sent to this channel. + queryChan chan string +} + +type testCtxConfig struct { + // skipREPLRun when set to true the REPL instance won't be executed. + skipREPLRun bool + // handleCustomQueries when set to true the PostgreSQL test server will + // accept any query sent and reply with success. + handleCustomQueries bool +} + +// testCtxOption represents a testCtx option. +type testCtxOption func(*testCtxConfig) + +// WithCustomQueries enables sending custom queries to the PostgreSQL test +// server. Note that when it is enabled, callers must consume the queries on the +// query channel. +func WithCustomQueries() testCtxOption { + return func(cfg *testCtxConfig) { + cfg.handleCustomQueries = true + } +} + +// WithSkipREPLRun disables automatically running the REPL instance. +func WithSkipREPLRun() testCtxOption { + return func(cfg *testCtxConfig) { + cfg.skipREPLRun = true + } +} + +// StartWithServer starts a REPL instance with a PostgreSQL test server capable +// of receiving and replying to queries. +func StartWithServer(t *testing.T, ctx context.Context, opts ...testCtxOption) (*REPL, *testCtx) { + t.Helper() + + cfg := &testCtxConfig{} + for _, opt := range opts { + opt(cfg) + } + + conn, clientConn := net.Pipe() + serverConn, pgConn := net.Pipe() + client := pgproto3.NewBackend(pgproto3.NewChunkReader(pgConn), pgConn) + ctx, cancelFunc := context.WithCancel(ctx) + tc := &testCtx{ + cfg: cfg, + ctx: ctx, + cancelFunc: cancelFunc, + conn: conn, + clientConn: clientConn, + serverConn: serverConn, + rawPgConn: pgConn, + pgClient: client, + errChan: make(chan error, 1), + terminateChan: make(chan struct{}), + queryChan: make(chan string), + } + + t.Cleanup(func() { + tc.close() + + select { + case err := <-tc.errChan: + require.NoError(t, err) + case <-time.After(5 * time.Second): + require.Fail(t, "expected to receive the test server close result but got nothing") + } + }) + + go func(c *testCtx) { + defer close(c.errChan) + if err := c.processMessages(); err != nil && !errors.Is(err, io.ErrClosedPipe) { + c.errChan <- err + } + }(tc) + + instance, err := New(ctx, &dbrepl.NewREPLConfig{Client: tc.clientConn, ServerConn: tc.serverConn, Route: tc.route}) + require.NoError(t, err) + + if !cfg.skipREPLRun { + // Start the REPL session and return to the caller a channel that will + // receive the execution result so it can assert REPL executions. + runCtx, cancelRun := context.WithCancel(ctx) + runErrChan := make(chan error, 1) + go func() { + runErrChan <- instance.Run(runCtx) + }() + t.Cleanup(func() { + cancelRun() + + select { + case err := <-runErrChan: + if !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) { + require.Fail(t, "expected the REPL instance to finish with context cancelation or server closed pipe but got %q", err) + } + case <-time.After(10 * time.Second): + require.Fail(t, "timeout while waiting for REPL Run result") + } + }) + } + + r, _ := instance.(*REPL) + return r, tc +} + +func (tc *testCtx) QueryChan() chan string { + return tc.queryChan +} + +func (tc *testCtx) CloseServer() { + tc.rawPgConn.Close() +} + +func (tc *testCtx) close() { + tc.serverConn.Close() + tc.clientConn.Close() +} + +func (tc *testCtx) processMessages() error { + defer tc.close() + + startupMessage, err := tc.pgClient.ReceiveStartupMessage() + if err != nil { + return trace.Wrap(err) + } + + switch msg := startupMessage.(type) { + case *pgproto3.StartupMessage: + // Accept auth and send ready for query. + if err := tc.pgClient.Send(&pgproto3.AuthenticationOk{}); err != nil { + return trace.Wrap(err) + } + + // Values on the backend key data are not relavant since we don't + // support canceling requests. + err := tc.pgClient.Send(&pgproto3.BackendKeyData{ + ProcessID: 0, + SecretKey: 123, + }) + if err != nil { + return trace.Wrap(err) + } + + if err := tc.pgClient.Send(&pgproto3.ReadyForQuery{}); err != nil { + return trace.Wrap(err) + } + default: + return trace.BadParameter("expected *pgproto3.StartupMessage, got: %T", msg) + } + + for { + message, err := tc.pgClient.Receive() + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return nil + } + + return trace.Wrap(err) + } + + var messages []pgproto3.BackendMessage + switch msg := message.(type) { + case *pgproto3.Query: + if tc.cfg.handleCustomQueries { + select { + case tc.queryChan <- msg.String: + messages = []pgproto3.BackendMessage{ + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("INSERT 0 1")}, + &pgproto3.ReadyForQuery{}, + } + case <-tc.ctx.Done(): + return trace.Wrap(tc.ctx.Err()) + } + + break // breaks the message switch case. + } + + switch msg.String { + case singleRowQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("id")}, {Name: []byte("email")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1"), []byte("alice@example.com")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, + } + case multiRowQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("id")}, {Name: []byte("email")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1"), []byte("alice@example.com")}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("2"), []byte("bob@example.com")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, + } + case dataTypesQuery: + messages = testdata.TestDataQueryResult + case multiQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("?column?")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{{Name: []byte("id")}, {Name: []byte("email")}}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("1"), []byte("alice@example.com")}}, + &pgproto3.DataRow{Values: [][]byte{[]byte("2"), []byte("bob@example.com")}}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, + } + case errorQuery: + messages = []pgproto3.BackendMessage{ + &pgproto3.ErrorResponse{Severity: "ERROR", Code: "42703", Message: "error"}, + &pgproto3.ReadyForQuery{}, + } + default: + return trace.BadParameter("unsupported query %q", msg.String) + + } + case *pgproto3.Terminate: + close(tc.terminateChan) + return nil + default: + return trace.BadParameter("unsupported message %#v", message) + } + + for _, message := range messages { + err := tc.pgClient.Send(message) + if err != nil { + return trace.Wrap(err) + } + } + } +} + +const ( + singleRowQuery = "SELECT * FROM users LIMIT 1;" + multiRowQuery = "SELECT * FROM users;" + multiQuery = "SELECT 1; SELECT * FROM users;" + dataTypesQuery = "SELECT * FROM test_data_types;" + errorQuery = "SELECT err;" +) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/data_type.golden b/lib/client/db/postgres/repl/testdata/TestStart/data_type.golden new file mode 100644 index 0000000000000..725af38776034 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/data_type.golden @@ -0,0 +1,4 @@ +serial_col int_col smallint_col bigint_col decimal_col numeric_col real_col double_col smallserial_col bigserial_col char_col varchar_col text_col boolean_col date_col time_col timetz_col timestamp_col timestamptz_col interval_col uuid_col json_col jsonb_col xml_col bytea_col inet_col cidr_col macaddr_col point_col line_col lseg_col box_col path_col polygon_col circle_col tsquery_col tsvector_col +---------- ------- ------------ ------------------- ----------- ----------- -------- ----------------- --------------- ------------- ---------- ------------------- ---------------- ----------- ---------- -------- ----------- ------------------- ---------------------- ----------------------------- ------------------------------------ ---------------- ---------------- --------------------------------------- ------------------------ ----------- -------------- ----------------- --------- -------- ------------- ----------- ------------------- ------------------- ---------- ------------- -------------------------------------------------- +1 42 32767 9223372036854775807 12345.67 98765.43210 3.14 2.718281828459045 1 1 A Sample varchar text Sample text data t 2024-11-29 12:34:56 12:34:56+03 2024-11-29 12:34:56 2024-11-29 09:34:56+00 1 year 2 mons 3 days 04:05:06 550e8400-e29b-41d4-a716-446655440000 {"key": "value"} {"key": "value"} XML content \x48656c6c6f20576f726c64 192.168.1.1 192.168.1.0/24 08:00:2b:01:02:03 (1,2) {1,-1,0} [(0,0),(1,1)] (1,1),(0,0) ((0,0),(1,1),(2,2)) ((0,0),(1,1),(1,0)) <(0,0),1> 'fat' & 'rat' 'a' 'and' 'ate' 'cat' 'fat' 'mat' 'on' 'rat' 'sat' +(1 row) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/err.golden b/lib/client/db/postgres/repl/testdata/TestStart/err.golden new file mode 100644 index 0000000000000..1dd89d57178c7 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/err.golden @@ -0,0 +1 @@ +ERR ERROR: error (SQLSTATE 42703) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/multi.golden b/lib/client/db/postgres/repl/testdata/TestStart/multi.golden new file mode 100644 index 0000000000000..43b92f3157fbb --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/multi.golden @@ -0,0 +1,5 @@ +id email +-- ----------------- +1 alice@example.com +2 bob@example.com +(2 rows) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden b/lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden new file mode 100644 index 0000000000000..3d3724d7186be --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/multiquery.golden @@ -0,0 +1,10 @@ +?column? +-------- +1 +(1 row) + +id email +-- ----------------- +1 alice@example.com +2 bob@example.com +(2 rows) diff --git a/lib/client/db/postgres/repl/testdata/TestStart/single.golden b/lib/client/db/postgres/repl/testdata/TestStart/single.golden new file mode 100644 index 0000000000000..c6ac2ed5ce793 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/TestStart/single.golden @@ -0,0 +1,4 @@ +id email +-- ----------------- +1 alice@example.com +(1 row) diff --git a/lib/client/db/postgres/repl/testdata/query.go b/lib/client/db/postgres/repl/testdata/query.go new file mode 100644 index 0000000000000..6789b128c1f37 --- /dev/null +++ b/lib/client/db/postgres/repl/testdata/query.go @@ -0,0 +1,111 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package testdata + +import ( + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" +) + +// Contains a query result with the most common fields in PostgreSQL. +// This can be used to understand how the REPL deals with different data types. +// +// Sampled from https://github.com/postgres/postgres/blob/b6612aedc53a6bf069eba5e356a8421ad6426486/src/include/catalog/pg_type.dat +// PostgreSQL version 17.2 +var TestDataQueryResult = []pgproto3.BackendMessage{ + &pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + // TableOID and TableAttributeNumber values omitted. + {Name: []byte("serial_col"), DataTypeOID: 23, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("int_col"), DataTypeOID: 23, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("smallint_col"), DataTypeOID: 21, DataTypeSize: 2, TypeModifier: -1, Format: 0}, + {Name: []byte("bigint_col"), DataTypeOID: 20, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("decimal_col"), DataTypeOID: 1700, DataTypeSize: -1, TypeModifier: 655366, Format: 0}, + {Name: []byte("numeric_col"), DataTypeOID: 1700, DataTypeSize: -1, TypeModifier: 983049, Format: 0}, + {Name: []byte("real_col"), DataTypeOID: 700, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("double_col"), DataTypeOID: 701, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("smallserial_col"), DataTypeOID: 21, DataTypeSize: 2, TypeModifier: -1, Format: 0}, + {Name: []byte("bigserial_col"), DataTypeOID: 20, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("char_col"), DataTypeOID: 1042, DataTypeSize: -1, TypeModifier: 14, Format: 0}, + {Name: []byte("varchar_col"), DataTypeOID: 1043, DataTypeSize: -1, TypeModifier: 54, Format: 0}, + {Name: []byte("text_col"), DataTypeOID: 25, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("boolean_col"), DataTypeOID: 16, DataTypeSize: 1, TypeModifier: -1, Format: 0}, + {Name: []byte("date_col"), DataTypeOID: 1082, DataTypeSize: 4, TypeModifier: -1, Format: 0}, + {Name: []byte("time_col"), DataTypeOID: 1083, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("timetz_col"), DataTypeOID: 1266, DataTypeSize: 12, TypeModifier: -1, Format: 0}, + {Name: []byte("timestamp_col"), DataTypeOID: 1114, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("timestamptz_col"), DataTypeOID: 1184, DataTypeSize: 8, TypeModifier: -1, Format: 0}, + {Name: []byte("interval_col"), DataTypeOID: 1186, DataTypeSize: 16, TypeModifier: -1, Format: 0}, + {Name: []byte("uuid_col"), DataTypeOID: 2950, DataTypeSize: 16, TypeModifier: -1, Format: 0}, + {Name: []byte("json_col"), DataTypeOID: 114, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("jsonb_col"), DataTypeOID: 3802, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("xml_col"), DataTypeOID: 142, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("bytea_col"), DataTypeOID: 17, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("inet_col"), DataTypeOID: 869, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("cidr_col"), DataTypeOID: 650, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("macaddr_col"), DataTypeOID: 829, DataTypeSize: 6, TypeModifier: -1, Format: 0}, + {Name: []byte("point_col"), DataTypeOID: 600, DataTypeSize: 16, TypeModifier: -1, Format: 0}, + {Name: []byte("line_col"), DataTypeOID: 628, DataTypeSize: 24, TypeModifier: -1, Format: 0}, + {Name: []byte("lseg_col"), DataTypeOID: 601, DataTypeSize: 32, TypeModifier: -1, Format: 0}, + {Name: []byte("box_col"), DataTypeOID: 603, DataTypeSize: 32, TypeModifier: -1, Format: 0}, + {Name: []byte("path_col"), DataTypeOID: 602, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("polygon_col"), DataTypeOID: 604, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("circle_col"), DataTypeOID: 718, DataTypeSize: 24, TypeModifier: -1, Format: 0}, + {Name: []byte("tsquery_col"), DataTypeOID: 3615, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + {Name: []byte("tsvector_col"), DataTypeOID: 3614, DataTypeSize: -1, TypeModifier: -1, Format: 0}, + }}, + &pgproto3.DataRow{Values: [][]byte{ + []byte("1"), + []byte("42"), + []byte("32767"), + []byte("9223372036854775807"), + []byte("12345.67"), + []byte("98765.43210"), + []byte("3.14"), + []byte("2.718281828459045"), + []byte("1"), + []byte("1"), + []byte("A "), + []byte("Sample varchar text"), + []byte("Sample text data"), + []byte("t"), + []byte("2024-11-29"), + []byte("12:34:56"), + []byte("12:34:56+03"), + []byte("2024-11-29 12:34:56"), + []byte("2024-11-29 09:34:56+00"), + []byte("1 year 2 mons 3 days 04:05:06"), + []byte("550e8400-e29b-41d4-a716-446655440000"), + []byte("{\"key\": \"value\"}"), + []byte("{\"key\": \"value\"}"), + []byte("XML content"), + []byte("\\x48656c6c6f20576f726c64"), + []byte("192.168.1.1"), + []byte("192.168.1.0/24"), + []byte("08:00:2b:01:02:03"), + []byte("(1,2)"), + []byte("{1,-1,0}"), + []byte("[(0,0),(1,1)]"), + []byte("(1,1),(0,0)"), + []byte("((0,0),(1,1),(2,2))"), + []byte("((0,0),(1,1),(1,0))"), + []byte("<(0,0),1>"), + []byte("'fat' & 'rat'"), + []byte("'a' 'and' 'ate' 'cat' 'fat' 'mat' 'on' 'rat' 'sat'"), + }}, + &pgproto3.CommandComplete{CommandTag: pgconn.CommandTag("SELECT")}, + &pgproto3.ReadyForQuery{}, +} diff --git a/lib/client/db/repl/repl.go b/lib/client/db/repl/repl.go new file mode 100644 index 0000000000000..abfed3dd5b8b6 --- /dev/null +++ b/lib/client/db/repl/repl.go @@ -0,0 +1,80 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package repl + +import ( + "context" + "io" + "net" + + "github.com/gravitational/trace" + + clientproto "github.com/gravitational/teleport/api/client/proto" +) + +// NewREPLConfig represents the database REPL constructor config. +type NewREPLConfig struct { + // Client is the user terminal client. + Client io.ReadWriteCloser + // ServerConn is the database server connection. + ServerConn net.Conn + // Route is the session routing information. + Route clientproto.RouteToDatabase +} + +// REPLNewFunc defines the constructor function for database REPL +// sessions. +type REPLNewFunc func(context.Context, *NewREPLConfig) (REPLInstance, error) + +// REPLInstance represents a REPL instance. +type REPLInstance interface { + // Run executes the REPL. This is a blocking operation. + Run(context.Context) error +} + +// REPLRegistry is an interface for initializing REPL instances and checking +// if the database protocol is supported. +type REPLRegistry interface { + // IsSupported returns if a database protocol is supported by any REPL. + IsSupported(protocol string) bool + // NewInstance initializes a new REPL instance given the configuration. + NewInstance(context.Context, *NewREPLConfig) (REPLInstance, error) +} + +// NewREPLGetter creates a new REPL getter given the list of supported REPLs. +func NewREPLGetter(replNewFuncs map[string]REPLNewFunc) REPLRegistry { + return &replRegistry{m: replNewFuncs} +} + +type replRegistry struct { + m map[string]REPLNewFunc +} + +// IsSupported implements REPLGetter. +func (r *replRegistry) IsSupported(protocol string) bool { + _, supported := r.m[protocol] + return supported +} + +// NewInstance implements REPLGetter. +func (r *replRegistry) NewInstance(ctx context.Context, cfg *NewREPLConfig) (REPLInstance, error) { + if newFunc, ok := r.m[cfg.Route.Protocol]; ok { + return newFunc(ctx, cfg) + } + + return nil, trace.NotImplemented("REPL not supported for protocol %q", cfg.Route.Protocol) +} diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index f2aa8ae7f0be0..9597dbcd46016 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -721,6 +721,10 @@ const ( // WebsocketKubeExec provides latency information for a session. WebsocketKubeExec = "k" + + // WebsocketDatabaseSessionRequest is received when a new database session + // is requested. + WebsocketDatabaseSessionRequest = "d" ) // The following are cryptographic primitives Teleport does not support in diff --git a/lib/service/service.go b/lib/service/service.go index 347f5f2a1a1c9..ec5b379896f7b 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -107,6 +107,8 @@ import ( _ "github.com/gravitational/teleport/lib/backend/pgbk" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/cache" + pgrepl "github.com/gravitational/teleport/lib/client/db/postgres/repl" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/cloud/imds" @@ -1084,6 +1086,12 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { cfg.PluginRegistry = plugin.NewRegistry() } + if cfg.DatabaseREPLRegistry == nil { + cfg.DatabaseREPLRegistry = dbrepl.NewREPLGetter(map[string]dbrepl.REPLNewFunc{ + defaults.ProtocolPostgres: pgrepl.New, + }) + } + var cloudLabels labels.Importer // Check if we're on a cloud instance, and if we should override the node's hostname. @@ -4652,6 +4660,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels, IntegrationAppHandler: connectionsHandler, FeatureWatchInterval: retryutils.HalfJitter(web.DefaultFeatureWatchInterval * 2), + DatabaseREPLRegistry: cfg.DatabaseREPLRegistry, } webHandler, err := web.NewHandler(webConfig) if err != nil { diff --git a/lib/service/servicecfg/config.go b/lib/service/servicecfg/config.go index d3dd237da74cd..6a14f1ceba5d0 100644 --- a/lib/service/servicecfg/config.go +++ b/lib/service/servicecfg/config.go @@ -43,6 +43,7 @@ import ( "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/lite" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/cloud/imds" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -265,6 +266,10 @@ type Config struct { // AccessGraph represents AccessGraph server config AccessGraph AccessGraphConfig + // DatabaseREPLRegistry is used to retrieve datatabase REPL given the + // protocol. + DatabaseREPLRegistry dbrepl.REPLRegistry + // token is either the token needed to join the auth server, or a path pointing to a file // that contains the token // diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 73752db274baa..43e4c7bd179bd 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -83,6 +83,7 @@ import ( "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/client" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/client/sso" "github.com/gravitational/teleport/lib/defaults" dtconfig "github.com/gravitational/teleport/lib/devicetrust/config" @@ -332,6 +333,9 @@ type Config struct { // FeatureWatchInterval is the interval between pings to the auth server // to fetch new cluster features FeatureWatchInterval time.Duration + + // DatabaseREPLRegistry is used for retrieving database REPL. + DatabaseREPLRegistry dbrepl.REPLRegistry } // SetDefaults ensures proper default values are set if @@ -837,6 +841,7 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions h.GET("/webapi/sites/:site/kube/exec/ws", h.WithClusterAuthWebSocket(h.podConnect)) // connect to a pod with exec (via websocket, with auth over websocket) + h.GET("/webapi/sites/:site/db/exec/ws", h.WithClusterAuthWebSocket(h.dbConnect)) // Audit events handlers. h.GET("/webapi/sites/:site/events/search", h.WithClusterAuth(h.clusterSearchEvents)) // search site events @@ -3055,9 +3060,6 @@ func (h *Handler) clusterUnifiedResourcesGet(w http.ResponseWriter, request *htt getUserGroupLookup := h.getUserGroupLookup(request.Context(), clt) - var dbNames, dbUsers []string - hasFetchedDBUsersAndNames := false - unifiedResources := make([]any, 0, len(page)) for _, enriched := range page { switch r := enriched.ResourceWithLabels.(type) { @@ -3069,14 +3071,7 @@ func (h *Handler) clusterUnifiedResourcesGet(w http.ResponseWriter, request *htt unifiedResources = append(unifiedResources, ui.MakeServer(site.GetName(), r, logins, enriched.RequiresRequest)) case types.DatabaseServer: - if !hasFetchedDBUsersAndNames { - dbNames, dbUsers, err = getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - hasFetchedDBUsersAndNames = true - } - db := ui.MakeDatabase(r.GetDatabase(), dbUsers, dbNames, enriched.RequiresRequest) + db := ui.MakeDatabase(r.GetDatabase(), accessChecker, h.cfg.DatabaseREPLRegistry, enriched.RequiresRequest) unifiedResources = append(unifiedResources, db) case types.AppServer: allowedAWSRoles, err := calculateAppLogins(accessChecker, r, enriched.Logins) @@ -3570,6 +3565,7 @@ func (h *Handler) siteNodeConnect( } term, err := NewTerminal(ctx, TerminalHandlerConfig{ + Logger: h.logger, Term: req.Term, SessionCtx: sessionCtx, UserAuthClient: clt, @@ -3722,6 +3718,7 @@ func (h *Handler) podConnect( ws: ws, keepAliveInterval: keepAliveInterval, log: h.log.WithField(teleport.ComponentKey, "pod"), + logger: h.logger.With(teleport.ComponentKey, "pod"), userClient: clt, localCA: hostCA, configServerAddr: serverAddr, diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index ff0f12fdc20cb..1925f67e964c2 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -111,6 +111,7 @@ import ( "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/conntest" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -210,6 +211,9 @@ type webSuiteConfig struct { // clock to use for all server components clock clockwork.FakeClock + + // databaseREPLGetter allows setting custom database REPLs. + databaseREPLGetter dbrepl.REPLRegistry } func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { @@ -509,6 +513,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { return &proxyClientCert, nil }, IntegrationAppHandler: &mockIntegrationAppHandler{}, + DatabaseREPLRegistry: cfg.databaseREPLGetter, } if handlerConfig.HealthCheckAppServer == nil { @@ -7437,6 +7442,7 @@ func TestOverwriteDatabase(t *testing.T) { env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "user", nil /* roles */) + accessChecker := services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, env.server.ClusterName(), nil) initDb, err := types.NewDatabaseV3(types.Metadata{ Name: "postgres", @@ -7477,7 +7483,8 @@ func TestOverwriteDatabase(t *testing.T) { backendDb, err := env.server.Auth().GetDatabase(context.Background(), req.Name) require.NoError(t, err) - require.Equal(t, webui.MakeDatabase(backendDb, nil, nil, false), gotDb) + + require.Equal(t, webui.MakeDatabase(backendDb, accessChecker, proxy.handler.handler.cfg.DatabaseREPLRegistry, false), gotDb) }, }, { @@ -8390,6 +8397,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula return &proxyClientCert, nil }, IntegrationAppHandler: &mockIntegrationAppHandler{}, + DatabaseREPLRegistry: &mockDatabaseREPLRegistry{repl: map[string]dbrepl.REPLNewFunc{}}, }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(clock)) require.NoError(t, err) diff --git a/lib/web/databases.go b/lib/web/databases.go index 557385d90f84c..f84beeb4ef7a8 100644 --- a/lib/web/databases.go +++ b/lib/web/databases.go @@ -21,28 +21,46 @@ package web import ( "context" "crypto/sha1" + "crypto/tls" "encoding/base64" "encoding/json" "encoding/pem" + "errors" "fmt" + "io" + "log/slog" "net" "net/http" "net/url" + "time" + gogoproto "github.com/gogo/protobuf/proto" + "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" + oteltrace "go.opentelemetry.io/otel/trace" "github.com/gravitational/teleport/api/client/proto" + clientproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/tlsutils" + "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/client" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" + alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" "github.com/gravitational/teleport/lib/ui" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/listener" "github.com/gravitational/teleport/lib/web/scripts" + "github.com/gravitational/teleport/lib/web/terminal" webui "github.com/gravitational/teleport/lib/web/ui" ) @@ -141,12 +159,8 @@ func (h *Handler) handleDatabaseCreateOrOverwrite(w http.ResponseWriter, r *http if err != nil { return nil, trace.Wrap(err) } - dbNames, dbUsers, err := getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - return webui.MakeDatabase(database, dbUsers, dbNames, false /* requiresRequest */), nil + return webui.MakeDatabase(database, accessChecker, h.cfg.DatabaseREPLRegistry, false /* requiresRequest */), nil } // updateDatabaseRequest contains some updatable fields of a database resource. @@ -254,7 +268,12 @@ func (h *Handler) handleDatabasePartialUpdate(w http.ResponseWriter, r *http.Req return nil, trace.Wrap(err) } - return webui.MakeDatabase(database, nil /* dbUsers */, nil /* dbNames */, false /* requiresRequest */), nil + accessChecker, err := sctx.GetUserAccessChecker() + if err != nil { + return nil, trace.Wrap(err) + } + + return webui.MakeDatabase(database, accessChecker, h.cfg.DatabaseREPLRegistry, false /* requiresRequest */), nil } // databaseIAMPolicyResponse is the response type for handleDatabaseGetIAMPolicy. @@ -387,6 +406,310 @@ func (h *Handler) sqlServerConfigureADScriptHandle(w http.ResponseWriter, r *htt return nil, trace.Wrap(err) } +func (h *Handler) dbConnect( + w http.ResponseWriter, + r *http.Request, + p httprouter.Params, + sctx *SessionContext, + site reversetunnelclient.RemoteSite, + ws *websocket.Conn, +) (interface{}, error) { + // Create a context for signaling when the terminal session is over and + // link it first with the trace context from the request context + tctx := oteltrace.ContextWithRemoteSpanContext(context.Background(), oteltrace.SpanContextFromContext(r.Context())) + ctx, cancel := context.WithCancel(tctx) + defer cancel() + h.logger.DebugContext(ctx, "Received database interactive connection") + + req, err := readDatabaseSessionRequest(ws) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || terminal.IsOKWebsocketCloseError(trace.Unwrap(err)) { + h.logger.DebugContext(ctx, "Database interactive session closed before receiving request") + return nil, nil + } + + var netError net.Error + if errors.As(trace.Unwrap(err), &netError) && netError.Timeout() { + return nil, trace.BadParameter("timed out waiting for database connect request data on websocket connection") + } + + return nil, trace.Wrap(err) + } + + log := h.logger.With( + "protocol", req.Protocol, + "service_name", req.ServiceName, + "database_name", req.DatabaseName, + "database_user", req.DatabaseUser, + "database_roles", req.DatabaseRoles, + ) + log.DebugContext(ctx, "Received database interactive session request") + + if !h.cfg.DatabaseREPLRegistry.IsSupported(req.Protocol) { + log.ErrorContext(ctx, "Unsupported database protocol") + return nil, trace.NotImplemented("%q database protocol not supported for REPL sessions", req.Protocol) + } + + accessPoint, err := site.CachingAccessPoint() + if err != nil { + return nil, trace.Wrap(err) + } + + netConfig, err := accessPoint.GetClusterNetworkingConfig(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + clt, err := sctx.GetUserClient(ctx, site) + if err != nil { + return nil, trace.Wrap(err) + } + + stream := terminal.NewStream(ctx, terminal.StreamConfig{WS: ws}) + defer stream.Close() + + replConn, alpnConn := net.Pipe() + sess := &databaseInteractiveSession{ + ctx: ctx, + log: log, + req: req, + stream: stream, + ws: ws, + sctx: sctx, + site: site, + clt: clt, + replConn: replConn, + alpnConn: alpnConn, + keepAliveInterval: netConfig.GetKeepAliveInterval(), + registry: h.cfg.DatabaseREPLRegistry, + proxyAddr: h.PublicProxyAddr(), + } + defer sess.Close() + + if err := sess.Run(); err != nil { + log.ErrorContext(ctx, "Database interactive session exited with error", "error", err) + return nil, trace.Wrap(err) + } + + return nil, nil +} + +// DatabaseSessionRequest describes a request to create a web-based terminal +// database session. +type DatabaseSessionRequest struct { + // ServiceName is the database resource ID the user will be connected. + ServiceName string `json:"serviceName"` + // Protocol is the database protocol. + Protocol string `json:"protocol"` + // DatabaseName is the database name the session will use. + DatabaseName string `json:"dbName"` + // DatabaseUser is the database user used on the session. + DatabaseUser string `json:"dbUser"` + // DatabaseRoles are ratabase roles that will be attached to the user when + // connecting to the database. + DatabaseRoles []string `json:"dbRoles"` +} + +// databaseConnectionRequestWaitTimeout defines how long the server will wait +// for the user to send the connection request. +const databaseConnectionRequestWaitTimeout = defaults.HeadlessLoginTimeout + +// readDatabaseSessionRequest reads the database session requestion message from +// websocket connection. +func readDatabaseSessionRequest(ws *websocket.Conn) (*DatabaseSessionRequest, error) { + err := ws.SetReadDeadline(time.Now().Add(databaseConnectionRequestWaitTimeout)) + if err != nil { + return nil, trace.Wrap(err, "failed to set read deadline for websocket connection") + } + + messageType, bytes, err := ws.ReadMessage() + if err != nil { + return nil, trace.Wrap(err) + } + + if err := ws.SetReadDeadline(time.Time{}); err != nil { + return nil, trace.Wrap(err, "failed to set read deadline for websocket connection") + } + + if messageType != websocket.BinaryMessage { + return nil, trace.BadParameter("expected binary message of type websocket.BinaryMessage, got %v", messageType) + } + + var envelope terminal.Envelope + if err := gogoproto.Unmarshal(bytes, &envelope); err != nil { + return nil, trace.BadParameter("failed to parse envelope: %v", err) + } + + if envelope.Type != defaults.WebsocketDatabaseSessionRequest { + return nil, trace.BadParameter("expected database session request but got %q", envelope.Type) + } + + var req DatabaseSessionRequest + if err := json.Unmarshal([]byte(envelope.Payload), &req); err != nil { + return nil, trace.Wrap(err) + } + + return &req, nil +} + +type databaseInteractiveSession struct { + ctx context.Context + ws *websocket.Conn + stream *terminal.Stream + log *slog.Logger + req *DatabaseSessionRequest + sctx *SessionContext + site reversetunnelclient.RemoteSite + clt authclient.ClientI + replConn net.Conn + alpnConn net.Conn + keepAliveInterval time.Duration + registry dbrepl.REPLRegistry + proxyAddr string +} + +func (s *databaseInteractiveSession) Run() error { + tlsCert, route, err := s.issueCerts() + if err != nil { + return trace.Wrap(err) + } + + if err := s.sendSessionMetadata(); err != nil { + return trace.Wrap(err) + } + + alpnProtocol, err := alpncommon.ToALPNProtocol(route.Protocol) + if err != nil { + return trace.Wrap(err) + } + + go startWSPingLoop(s.ctx, s.ws, s.keepAliveInterval, s.log, s.Close) + + err = client.RunALPNAuthTunnel(s.ctx, client.ALPNAuthTunnelConfig{ + AuthClient: s.clt, + Listener: listener.NewSingleUseListener(s.alpnConn), + Protocol: alpnProtocol, + PublicProxyAddr: s.proxyAddr, + RouteToDatabase: *route, + TLSCert: tlsCert, + }) + if err != nil { + return trace.Wrap(err) + } + + repl, err := s.registry.NewInstance(s.ctx, &dbrepl.NewREPLConfig{ + Client: s.stream, + ServerConn: s.replConn, + Route: *route, + }) + if err != nil { + return trace.Wrap(err) + } + + s.log.DebugContext(s.ctx, "Starting database interactive session") + if err := repl.Run(s.ctx); err != nil { + return trace.Wrap(err) + } + + s.log.DebugContext(s.ctx, "Database interactive session exited with success") + return nil +} + +func (s *databaseInteractiveSession) Close() error { + s.replConn.Close() + return s.ws.Close() +} + +// issueCerts performs the MFA (if required) and generate the user session +// certificates. +func (s *databaseInteractiveSession) issueCerts() (*tls.Certificate, *clientproto.RouteToDatabase, error) { + pk, err := keys.ParsePrivateKey(s.sctx.cfg.Session.GetTLSPriv()) + if err != nil { + return nil, nil, trace.Wrap(err, "failed getting user private key from the session") + } + + publicKeyPEM, err := keys.MarshalPublicKey(pk.Public()) + if err != nil { + return nil, nil, trace.Wrap(err, "failed to marshal public key") + } + + routeToDatabase := clientproto.RouteToDatabase{ + Protocol: s.req.Protocol, + ServiceName: s.req.ServiceName, + Username: s.req.DatabaseUser, + Database: s.req.DatabaseName, + Roles: s.req.DatabaseRoles, + } + + certsReq := clientproto.UserCertsRequest{ + TLSPublicKey: publicKeyPEM, + Username: s.sctx.GetUser(), + Expires: s.sctx.cfg.Session.GetExpiryTime(), + Format: constants.CertificateFormatStandard, + RouteToCluster: s.site.GetName(), + Usage: clientproto.UserCertsRequest_Database, + RouteToDatabase: routeToDatabase, + } + + _, certs, err := client.PerformSessionMFACeremony(s.ctx, client.PerformSessionMFACeremonyParams{ + CurrentAuthClient: s.clt, + RootAuthClient: s.sctx.cfg.RootClient, + MFACeremony: newMFACeremony(s.stream.WSStream, s.sctx.cfg.RootClient.CreateAuthenticateChallenge), + MFAAgainstRoot: s.sctx.cfg.RootClusterName == s.site.GetName(), + MFARequiredReq: &clientproto.IsMFARequiredRequest{ + Target: &clientproto.IsMFARequiredRequest_Database{Database: &routeToDatabase}, + }, + CertsReq: &certsReq, + }) + if err != nil && !errors.Is(err, services.ErrSessionMFANotRequired) { + return nil, nil, trace.Wrap(err, "failed performing mfa ceremony") + } + + if certs == nil { + certs, err = s.sctx.cfg.RootClient.GenerateUserCerts(s.ctx, certsReq) + if err != nil { + return nil, nil, trace.Wrap(err, "failed issuing user certs") + } + } + + tlsCert, err := pk.TLSCertificate(certs.TLS) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return &tlsCert, &routeToDatabase, nil +} + +func (s *databaseInteractiveSession) sendSessionMetadata() error { + sessionMetadataResponse, err := json.Marshal(siteSessionGenerateResponse{Session: session.Session{ + // TODO(gabrielcorado): Have a consistent Session ID. Right now, the + // initial session ID returned won't be correct as the session is only + // initialized by the database server after the REPL starts. + ClusterName: s.site.GetName(), + }}) + if err != nil { + return trace.Wrap(err) + } + + envelope := &terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketSessionMetadata, + Payload: string(sessionMetadataResponse), + } + + envelopeBytes, err := gogoproto.Marshal(envelope) + if err != nil { + return trace.Wrap(err) + } + + err = s.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + // fetchDatabaseWithName fetch a database with provided database name. func fetchDatabaseWithName(ctx context.Context, clt resourcesAPIGetter, r *http.Request, databaseName string) (types.Database, error) { resp, err := clt.ListResources(ctx, proto.ListResourcesRequest{ diff --git a/lib/web/databases_test.go b/lib/web/databases_test.go index 3a0dd78d212d0..4a25b13557d68 100644 --- a/lib/web/databases_test.go +++ b/lib/web/databases_test.go @@ -20,25 +20,35 @@ package web import ( "context" + "crypto/tls" "encoding/json" "fmt" "net/http" "net/url" "regexp" + "sync" "testing" "time" + "github.com/gogo/protobuf/proto" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/stretchr/testify/require" + authproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" + wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" + "github.com/gravitational/teleport/lib/client" + dbrepl "github.com/gravitational/teleport/lib/client/db/repl" awslib "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/web/terminal" "github.com/gravitational/teleport/lib/web/ui" ) @@ -524,6 +534,216 @@ func TestHandleSQLServerConfigureScriptDatabaseURIEscaped(t *testing.T) { } } +func TestConnectDatabaseInteractiveSession(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + databaseProtocol := defaults.ProtocolPostgres + + // Use a mock REPL and modify it adding the additional configuration when + // it is set. + repl := &mockDatabaseREPL{message: "hello from repl"} + + s := newWebSuiteWithConfig(t, webSuiteConfig{ + disableDiskBasedRecording: true, + authPreferenceSpec: &types.AuthPreferenceSpecV2{ + Type: constants.Local, + ConnectorName: constants.PasswordlessConnector, + SecondFactor: constants.SecondFactorOn, + RequireMFAType: types.RequireMFAType_SESSION, + Webauthn: &types.Webauthn{ + RPID: "localhost", + }, + }, + databaseREPLGetter: &mockDatabaseREPLRegistry{ + repl: map[string]dbrepl.REPLNewFunc{ + databaseProtocol: func(ctx context.Context, c *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, error) { + repl.setConfig(c) + return repl, nil + }, + }, + }, + }) + s.webHandler.handler.cfg.PublicProxyAddr = s.webHandler.handler.cfg.ProxyWebAddr.String() + + accessRole, err := types.NewRole("access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + DatabaseNames: []string{types.Wildcard}, + DatabaseUsers: []string{types.Wildcard}, + }, + }) + require.NoError(t, err) + pack := s.authPackWithMFA(t, "user", accessRole) + + databaseName := "db" + selfHosted, err := types.NewDatabaseV3(types.Metadata{ + Name: databaseName, + }, types.DatabaseSpecV3{ + Protocol: databaseProtocol, + URI: "localhost:12345", + }) + require.NoError(t, err) + + _, err = s.server.Auth().UpsertDatabaseServer(ctx, mustCreateDatabaseServer(t, selfHosted)) + require.NoError(t, err) + + u := url.URL{ + Host: s.webServer.Listener.Addr().String(), + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()), + } + + header := http.Header{} + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + ws, resp, err := dialer.DialContext(ctx, u.String(), header) + require.NoError(t, err) + defer ws.Close() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token)) + + req := DatabaseSessionRequest{ + Protocol: databaseProtocol, + ServiceName: databaseName, + DatabaseName: "postgres", + DatabaseUser: "postgres", + DatabaseRoles: []string{"reader"}, + } + encodedReq, err := json.Marshal(req) + require.NoError(t, err) + reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketDatabaseSessionRequest, + Payload: string(encodedReq), + }) + require.NoError(t, err) + require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage)) + + performMFACeremonyWS(t, ws, pack) + + // After the MFA is performed we expect the WebSocket to receive the + // session data information. + sessionData := receiveWSMessage(t, ws) + require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type) + + // Assert data written by the REPL comes as raw data. + replResp := receiveWSMessage(t, ws) + require.Equal(t, defaults.WebsocketRaw, replResp.Type) + require.Equal(t, repl.message, replResp.Payload) + + require.NoError(t, ws.Close()) + require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed") +} + +func receiveWSMessage(t *testing.T, ws *websocket.Conn) terminal.Envelope { + t.Helper() + + typ, raw, err := ws.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.BinaryMessage, typ) + var env terminal.Envelope + require.NoError(t, proto.Unmarshal(raw, &env)) + return env +} + +func performMFACeremonyWS(t *testing.T, ws *websocket.Conn, pack *authPack) { + t.Helper() + + ty, raw, err := ws.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.BinaryMessage, ty, "got unexpected websocket message type %d", ty) + + var env terminal.Envelope + require.NoError(t, proto.Unmarshal(raw, &env)) + + var challenge client.MFAAuthenticateChallenge + require.NoError(t, json.Unmarshal([]byte(env.Payload), &challenge)) + + res, err := pack.device.SolveAuthn(&authproto.MFAAuthenticateChallenge{ + WebauthnChallenge: wantypes.CredentialAssertionToProto(challenge.WebauthnChallenge), + }) + require.NoError(t, err) + + webauthnResBytes, err := json.Marshal(wantypes.CredentialAssertionResponseFromProto(res.GetWebauthn())) + require.NoError(t, err) + + envelopeBytes, err := proto.Marshal(&terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketMFAChallenge, + Payload: string(webauthnResBytes), + }) + require.NoError(t, err) + require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) +} + +type mockDatabaseREPLRegistry struct { + repl map[string]dbrepl.REPLNewFunc +} + +// NewInstance implements repl.REPLGetter. +func (m *mockDatabaseREPLRegistry) NewInstance(ctx context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, error) { + if replFunc, ok := m.repl[cfg.Route.Protocol]; ok { + return replFunc(ctx, cfg) + } + + return nil, trace.NotImplemented("not supported") +} + +// IsSupported implements repl.REPLGetter. +func (m *mockDatabaseREPLRegistry) IsSupported(protocol string) bool { + _, supported := m.repl[protocol] + return supported +} + +type mockDatabaseREPL struct { + mu sync.Mutex + message string + cfg *dbrepl.NewREPLConfig + closed bool +} + +func (m *mockDatabaseREPL) Run(_ context.Context) error { + m.mu.Lock() + defer func() { + m.closeUnlocked() + m.mu.Unlock() + }() + + if _, err := m.cfg.Client.Write([]byte(m.message)); err != nil { + return trace.Wrap(err) + } + + if _, err := m.cfg.ServerConn.Write([]byte("Hello")); err != nil { + return trace.Wrap(err) + } + + return nil +} + +func (m *mockDatabaseREPL) setConfig(c *dbrepl.NewREPLConfig) { + m.mu.Lock() + defer m.mu.Unlock() + m.cfg = c +} + +func (m *mockDatabaseREPL) getClosed() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.closed +} + +func (m *mockDatabaseREPL) closeUnlocked() { + m.closed = true +} + func mustCreateDatabaseServer(t *testing.T, db *types.DatabaseV3) types.DatabaseServer { t.Helper() diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 331df6ba78dd1..710250757415b 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -87,9 +87,14 @@ func (h *Handler) awsOIDCListDatabases(w http.ResponseWriter, r *http.Request, p return nil, trace.Wrap(err) } + accessChecker, err := sctx.GetUserAccessChecker() + if err != nil { + return nil, trace.Wrap(err) + } + return ui.AWSOIDCListDatabasesResponse{ NextToken: listDatabasesResp.NextToken, - Databases: ui.MakeDatabases(listDatabasesResp.Databases, nil, nil), + Databases: ui.MakeDatabases(listDatabasesResp.Databases, accessChecker, h.cfg.DatabaseREPLRegistry), }, nil } diff --git a/lib/web/kube.go b/lib/web/kube.go index aad3a0a25c817..ccfd76380103f 100644 --- a/lib/web/kube.go +++ b/lib/web/kube.go @@ -23,6 +23,7 @@ import ( "context" "encoding/json" "errors" + "log/slog" "net/http" "strings" "sync/atomic" @@ -62,6 +63,7 @@ type podHandler struct { ws *websocket.Conn keepAliveInterval time.Duration log *logrus.Entry + logger *slog.Logger userClient authclient.ClientI localCA types.CertAuthority @@ -207,7 +209,7 @@ func (p *podHandler) handler(r *http.Request) error { }) // Start sending ping frames through websocket to the client. - go startWSPingLoop(r.Context(), p.ws, p.keepAliveInterval, p.log, p.Close) + go startWSPingLoop(r.Context(), p.ws, p.keepAliveInterval, p.logger, p.Close) pk, err := keys.ParsePrivateKey(p.sctx.cfg.Session.GetTLSPriv()) if err != nil { diff --git a/lib/web/servers.go b/lib/web/servers.go index af76b4d924620..8babf48dbb1bf 100644 --- a/lib/web/servers.go +++ b/lib/web/servers.go @@ -29,7 +29,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/ui" webui "github.com/gravitational/teleport/lib/web/ui" ) @@ -125,13 +124,8 @@ func (h *Handler) clusterDatabasesGet(w http.ResponseWriter, r *http.Request, p return nil, trace.Wrap(err) } - dbNames, dbUsers, err := getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - return listResourcesGetResponse{ - Items: webui.MakeDatabases(databases, dbUsers, dbNames), + Items: webui.MakeDatabases(databases, accessChecker, h.cfg.DatabaseREPLRegistry), StartKey: page.NextKey, TotalCount: page.Total, }, nil @@ -159,12 +153,7 @@ func (h *Handler) clusterDatabaseGet(w http.ResponseWriter, r *http.Request, p h return nil, trace.Wrap(err) } - dbNames, dbUsers, err := getDatabaseUsersAndNames(accessChecker) - if err != nil { - return nil, trace.Wrap(err) - } - - return webui.MakeDatabase(database, dbUsers, dbNames, false /* requiresRequest */), nil + return webui.MakeDatabase(database, accessChecker, h.cfg.DatabaseREPLRegistry, false /* requiresRequest */), nil } // clusterDatabaseServicesList returns a list of DatabaseServices (database agents) in a form the UI can present. @@ -333,25 +322,6 @@ func (h *Handler) desktopIsActive(w http.ResponseWriter, r *http.Request, p http return desktopIsActive{false}, nil } -func getDatabaseUsersAndNames(accessChecker services.AccessChecker) (dbNames []string, dbUsers []string, err error) { - dbNames, dbUsers, err = accessChecker.CheckDatabaseNamesAndUsers(0, true /* force ttl override*/) - if err != nil { - // if NotFound error: - // This user cannot request database access, has no assigned database names or users - // - // Every other error should be reported upstream. - if !trace.IsNotFound(err) { - return nil, nil, trace.Wrap(err) - } - - // We proceed with an empty list of DBUsers and DBNames - dbUsers = []string{} - dbNames = []string{} - } - - return dbNames, dbUsers, nil -} - type desktopIsActive struct { Active bool `json:"active"` } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 072e51ecf1140..9326140447eac 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/http" "net/url" @@ -126,6 +127,10 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl teleport.ComponentKey: teleport.ComponentWebsocket, "session_id": cfg.SessionData.ID.String(), }), + logger: cfg.Logger.With( + teleport.ComponentKey, teleport.ComponentWebsocket, + "session_id", cfg.SessionData.ID.String(), + ), ctx: cfg.SessionCtx, userAuthClient: cfg.UserAuthClient, localAccessPoint: cfg.LocalAccessPoint, @@ -152,6 +157,8 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl // TerminalHandlerConfig contains the configuration options necessary to // correctly set up the TerminalHandler type TerminalHandlerConfig struct { + // Logger specifies the logger. + Logger *slog.Logger // Term is the initial PTY size. Term session.TerminalParams // SessionCtx is the context for the users web session. @@ -205,6 +212,10 @@ type TerminalHandlerConfig struct { } func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { + if t.Logger == nil { + t.Logger = slog.Default().With(teleport.ComponentKey, teleport.ComponentWebsocket) + } + // Make sure whatever session is requested is a valid session id. if !t.SessionData.ID.IsZero() { _, err := session.ParseID(t.SessionData.ID.String()) @@ -259,6 +270,8 @@ func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { type sshBaseHandler struct { // log holds the structured logger. log *logrus.Entry + // logger holds the structured logger. + logger *slog.Logger // ctx is a web session context for the currently logged-in user. ctx *SessionContext // userAuthClient is used to fetch nodes and sessions from the backend via the users' identity. @@ -470,7 +483,7 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { }) // Start sending ping frames through websocket to client. - go startWSPingLoop(ctx, ws, t.keepAliveInterval, t.log, t.Close) + go startWSPingLoop(ctx, ws, t.keepAliveInterval, t.logger, t.Close) // Pump raw terminal in/out and audit events into the websocket. go t.streamEvents(ctx, tc) diff --git a/lib/web/ui/server.go b/lib/web/ui/server.go index 9921307c48134..93c3781aa6172 100644 --- a/lib/web/ui/server.go +++ b/lib/web/ui/server.go @@ -306,6 +306,9 @@ type Database struct { AWS *AWS `json:"aws,omitempty"` // RequireRequest indicates if a returned resource is only accessible after an access request RequiresRequest bool `json:"requiresRequest,omitempty"` + // SupportsInteractive is a flag to indicate the database supports + // interactive sessions using database REPLs. + SupportsInteractive bool `json:"supports_interactive,omitempty"` } // AWS contains AWS specific fields. @@ -322,22 +325,35 @@ const ( LabelStatus = "status" ) +// DatabaseInteractiveChecker is used to check if the database supports +// interactive sessions using database REPLs. +type DatabaseInteractiveChecker interface { + IsSupported(protocol string) bool +} + // MakeDatabase creates database objects. -func MakeDatabase(database types.Database, dbUsers, dbNames []string, requiresRequest bool) Database { +func MakeDatabase(database types.Database, accessChecker services.AccessChecker, interactiveChecker DatabaseInteractiveChecker, requiresRequest bool) Database { + dbNames := accessChecker.EnumerateDatabaseNames(database) + var dbUsers []string + if res, err := accessChecker.EnumerateDatabaseUsers(database); err == nil { + dbUsers = res.Allowed() + } + uiLabels := ui.MakeLabelsWithoutInternalPrefixes(database.GetAllLabels()) db := Database{ - Kind: database.GetKind(), - Name: database.GetName(), - Desc: database.GetDescription(), - Protocol: database.GetProtocol(), - Type: database.GetType(), - Labels: uiLabels, - DatabaseUsers: dbUsers, - DatabaseNames: dbNames, - Hostname: stripProtocolAndPort(database.GetURI()), - URI: database.GetURI(), - RequiresRequest: requiresRequest, + Kind: database.GetKind(), + Name: database.GetName(), + Desc: database.GetDescription(), + Protocol: database.GetProtocol(), + Type: database.GetType(), + Labels: uiLabels, + DatabaseUsers: dbUsers, + DatabaseNames: dbNames.Allowed(), + Hostname: stripProtocolAndPort(database.GetURI()), + URI: database.GetURI(), + RequiresRequest: requiresRequest, + SupportsInteractive: interactiveChecker.IsSupported(database.GetProtocol()), } if database.IsAWSHosted() { @@ -355,10 +371,10 @@ func MakeDatabase(database types.Database, dbUsers, dbNames []string, requiresRe } // MakeDatabases creates database objects. -func MakeDatabases(databases []*types.DatabaseV3, dbUsers, dbNames []string) []Database { +func MakeDatabases(databases []*types.DatabaseV3, accessChecker services.AccessChecker, interactiveChecker DatabaseInteractiveChecker) []Database { uiServers := make([]Database, 0, len(databases)) for _, database := range databases { - db := MakeDatabase(database, dbUsers, dbNames, false /* requiresRequest */) + db := MakeDatabase(database, accessChecker, interactiveChecker, false /* requiresRequest */) uiServers = append(uiServers, db) } diff --git a/lib/web/ui/server_test.go b/lib/web/ui/server_test.go index 9af621cf09557..514e9474b0c80 100644 --- a/lib/web/ui/server_test.go +++ b/lib/web/ui/server_test.go @@ -431,7 +431,8 @@ func TestMakeDatabaseHiddenLabels(t *testing.T) { }, } - outputDb := MakeDatabase(inputDb, nil, nil, false) + accessChecker := services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, "clusterName", nil) + outputDb := MakeDatabase(inputDb, accessChecker, &mockDatabaseInteractiveChecker{}, false) require.Equal(t, []ui.Label{ { @@ -590,3 +591,33 @@ func TestSortedLabels(t *testing.T) { }) } } + +func TestMakeDatabaseSupportsInteractive(t *testing.T) { + db := &types.DatabaseV3{} + accessChecker := services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, "clusterName", nil) + + for name, tc := range map[string]struct { + supports bool + }{ + "supported": {supports: true}, + "unsupported": {supports: false}, + } { + t.Run(name, func(t *testing.T) { + interactiveChecker := &mockDatabaseInteractiveChecker{supports: tc.supports} + single := MakeDatabase(db, accessChecker, interactiveChecker, false) + require.Equal(t, tc.supports, single.SupportsInteractive) + + multi := MakeDatabases([]*types.DatabaseV3{db}, accessChecker, interactiveChecker) + require.Len(t, multi, 1) + require.Equal(t, tc.supports, multi[0].SupportsInteractive) + }) + } +} + +type mockDatabaseInteractiveChecker struct { + supports bool +} + +func (m *mockDatabaseInteractiveChecker) IsSupported(_ string) bool { + return m.supports +} diff --git a/lib/web/ws_io.go b/lib/web/ws_io.go index 2ed1e4c548807..63055a376706d 100644 --- a/lib/web/ws_io.go +++ b/lib/web/ws_io.go @@ -20,11 +20,11 @@ package web import ( "context" + "log/slog" "time" "github.com/gorilla/websocket" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" ) type WebsocketIO struct { @@ -71,8 +71,8 @@ type wsPinger interface { // to prevent the connection between web client and teleport proxy from becoming idle. // Interval is determined by the keep_alive_interval config set by user (or default). // Loop will terminate when there is an error sending ping frame or when the context is canceled. -func startWSPingLoop(ctx context.Context, pinger wsPinger, keepAliveInterval time.Duration, log logrus.FieldLogger, onClose func() error) { - log.Debugf("Starting websocket ping loop with interval %v.", keepAliveInterval) +func startWSPingLoop(ctx context.Context, pinger wsPinger, keepAliveInterval time.Duration, log *slog.Logger, onClose func() error) { + log.DebugContext(ctx, "Starting websocket ping loop with interval", "interval", keepAliveInterval) tickerCh := time.NewTicker(keepAliveInterval) defer tickerCh.Stop() @@ -83,16 +83,16 @@ func startWSPingLoop(ctx context.Context, pinger wsPinger, keepAliveInterval tim // If this is just a temporary issue, we will retry shortly anyway. deadline := time.Now().Add(time.Second) if err := pinger.WriteControl(websocket.PingMessage, nil, deadline); err != nil { - log.WithError(err).Error("Unable to send ping frame to web client") + log.ErrorContext(ctx, "Unable to send ping frame to web client", "error", err) if onClose != nil { if err := onClose(); err != nil { - log.WithError(err).Error("OnClose handler failed") + log.ErrorContext(ctx, "OnClose handler failed", "error", err) } } return } case <-ctx.Done(): - log.Debug("Terminating websocket ping loop.") + log.DebugContext(ctx, "Terminating websocket ping loop.") return } } diff --git a/web/packages/teleport/src/Console/Console.tsx b/web/packages/teleport/src/Console/Console.tsx index 36cf572586dbd..e0bc17f879f4d 100644 --- a/web/packages/teleport/src/Console/Console.tsx +++ b/web/packages/teleport/src/Console/Console.tsx @@ -37,6 +37,7 @@ import usePageTitle from './usePageTitle'; import useTabRouting from './useTabRouting'; import useOnExitConfirmation from './useOnExitConfirmation'; import useKeyboardNav from './useKeyboardNav'; +import { DocumentDb } from './DocumentDb'; const POLL_INTERVAL = 5000; // every 5 sec @@ -77,7 +78,9 @@ export default function Console() { return consoleCtx.refreshParties(); } - const disableNewTab = storeDocs.getNodeDocuments().length > 0; + const disableNewTab = + storeDocs.getNodeDocuments().length > 0 || + storeDocs.getDbDocuments().length > 0; const $docs = documents.map(doc => ( )); @@ -139,6 +142,8 @@ function MemoizedDocument(props: { doc: stores.Document; visible: boolean }) { return ; case 'kubeExec': return ; + case 'db': + return ; default: return ; } diff --git a/web/packages/teleport/src/Console/DocumentDb/ConnectDialog.tsx b/web/packages/teleport/src/Console/DocumentDb/ConnectDialog.tsx new file mode 100644 index 0000000000000..6d91f451fe0ae --- /dev/null +++ b/web/packages/teleport/src/Console/DocumentDb/ConnectDialog.tsx @@ -0,0 +1,219 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import React, { useCallback, useEffect, useState } from 'react'; +import Dialog, { + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, +} from 'design/Dialog'; +import { Box, ButtonPrimary, ButtonSecondary, Flex, Indicator } from 'design'; + +import Validation from 'shared/components/Validation'; +import { Option } from 'shared/components/Select'; +import { + FieldSelect, + FieldSelectCreatable, +} from 'shared/components/FieldSelect'; + +import { Danger } from 'design/Alert'; +import { requiredField } from 'shared/components/Validation/rules'; +import { useAsync } from 'shared/hooks/useAsync'; + +import { useTeleport } from 'teleport'; +import { Database } from 'teleport/services/databases'; +import { DbConnectData } from 'teleport/lib/term/tty'; + +export function ConnectDialog(props: { + clusterId: string; + serviceName: string; + onClose(): void; + onConnect(data: DbConnectData): void; +}) { + // Fetch database information to pre-fill the connection parameters. + const ctx = useTeleport(); + const [attempt, getDatabase] = useAsync( + useCallback(async () => { + const response = await ctx.resourceService.fetchUnifiedResources( + props.clusterId, + { + query: `name == "${props.serviceName}"`, + kinds: ['db'], + sort: { fieldName: 'name', dir: 'ASC' }, + limit: 1, + } + ); + + // TODO(gabrielcorado): Handle scenarios where there is conflict on the name. + if (response.agents.length !== 1 || response.agents[0].kind !== 'db') { + throw new Error('Unable to retrieve database information.'); + } + + return response.agents[0]; + }, [props.clusterId, ctx.resourceService, props.serviceName]) + ); + + useEffect(() => { + void getDatabase(); + }, [getDatabase]); + + return ( + + + Connect To Database + + + {attempt.status === 'error' && } + {(attempt.status === '' || attempt.status === 'processing') && ( + + + + )} + {attempt.status === 'success' && ( + + )} + + ); +} + +function ConnectForm(props: { + db: Database; + onConnect(data: DbConnectData): void; + onClose(): void; +}) { + const dbUserOpts = props.db.users + ?.map(user => ({ + value: user, + label: user, + })) + .filter(removeWildcardOption); + const dbNamesOpts = props.db.names + ?.map(name => ({ + value: name, + label: name, + })) + .filter(removeWildcardOption); + const dbRolesOpts = props.db.roles + ?.map(role => ({ + value: role, + label: role, + })) + .filter(removeWildcardOption); + + const [selectedName, setSelectedName] = useState