diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 8d2a449b6f6e9..aca8dfe495bcf 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -795,8 +795,8 @@ func (h *Handler) bindDefaultEndpoints() { // Database access handlers. h.GET("/webapi/sites/:site/databases", h.WithClusterAuth(h.clusterDatabasesGet)) - h.POST("/webapi/sites/:site/databases", h.WithClusterAuth(h.handleDatabaseCreate)) - h.PUT("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.handleDatabaseUpdate)) + h.POST("/webapi/sites/:site/databases", h.WithClusterAuth(h.handleDatabaseCreateOrOverwrite)) + h.PUT("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.handleDatabasePartialUpdate)) h.GET("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.clusterDatabaseGet)) h.GET("/webapi/sites/:site/databases/:database/iam/policy", h.WithClusterAuth(h.handleDatabaseGetIAMPolicy)) h.GET("/webapi/scripts/databases/configure/sqlserver/:token/configure-ad.ps1", httplib.MakeHandler(h.sqlServerConfigureADScriptHandle)) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index d418ce193b36e..45ad30e4ef937 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7196,7 +7196,7 @@ func TestCreateDatabase(t *testing.T) { createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases") // Create an initial database to table test a duplicate creation - _, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{ + _, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createOrOverwriteDatabaseRequest{ Name: "duplicatedb", Protocol: "mysql", URI: "someuri:3306", @@ -7205,13 +7205,13 @@ func TestCreateDatabase(t *testing.T) { for _, tt := range []struct { name string - req createDatabaseRequest + req createOrOverwriteDatabaseRequest expectedStatus int errAssert require.ErrorAssertionFunc }{ { name: "valid", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "mydatabase", Protocol: "mysql", URI: "someuri:3306", @@ -7227,7 +7227,7 @@ func TestCreateDatabase(t *testing.T) { }, { name: "valid with labels", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "dbwithlabels", Protocol: "mysql", URI: "someuri:3306", @@ -7247,7 +7247,7 @@ func TestCreateDatabase(t *testing.T) { }, { name: "empty name", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "", Protocol: "mysql", URI: "someuri:3306", @@ -7259,7 +7259,7 @@ func TestCreateDatabase(t *testing.T) { }, { name: "empty protocol", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "emptyprotocol", Protocol: "", URI: "someuri:3306", @@ -7271,7 +7271,7 @@ func TestCreateDatabase(t *testing.T) { }, { name: "empty uri", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "emptyuri", Protocol: "mysql", URI: "", @@ -7283,7 +7283,7 @@ func TestCreateDatabase(t *testing.T) { }, { name: "missing port", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "missingport", Protocol: "mysql", URI: "someuri", @@ -7295,7 +7295,7 @@ func TestCreateDatabase(t *testing.T) { }, { name: "duplicatedb", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "duplicatedb", Protocol: "mysql", URI: "someuri:3306", @@ -7352,6 +7352,79 @@ func TestCreateDatabase(t *testing.T) { } } +func TestOverwriteDatabase(t *testing.T) { + t.Parallel() + + env := newWebPack(t, 1) + proxy := env.proxies[0] + pack := proxy.authPack(t, "user", nil /* roles */) + + initDb, err := types.NewDatabaseV3(types.Metadata{ + Name: "postgres", + }, types.DatabaseSpecV3{ + Protocol: "postgres", + URI: "localhost:5432", + AWS: types.AWS{ + AccountID: "123456789012", + }, + }) + require.NoError(t, err) + + err = env.server.Auth().CreateDatabase(context.Background(), initDb) + require.NoError(t, err) + + tests := []struct { + name string + req createOrOverwriteDatabaseRequest + verifyResponse func(*testing.T, *roundtrip.Response, createOrOverwriteDatabaseRequest, error) + }{ + { + name: "overwrite", + req: createOrOverwriteDatabaseRequest{ + Name: initDb.GetName(), + Overwrite: true, + URI: "some-other-uri:3306", + Protocol: "postgres", + }, + verifyResponse: func(t *testing.T, resp *roundtrip.Response, req createOrOverwriteDatabaseRequest, err error) { + require.NoError(t, err) + + var gotDb ui.Database + require.NoError(t, json.Unmarshal(resp.Bytes(), &gotDb)) + require.Equal(t, req.URI, gotDb.URI) + require.Equal(t, req.Protocol, gotDb.Protocol) + require.Empty(t, req.AWSRDS) + require.Equal(t, initDb.GetName(), gotDb.Name) + + backendDb, err := env.server.Auth().GetDatabase(context.Background(), req.Name) + require.NoError(t, err) + require.Equal(t, ui.MakeDatabase(backendDb, nil, nil, false), gotDb) + }, + }, + { + name: "overwrite error: database does not exist", + req: createOrOverwriteDatabaseRequest{ + Name: "this-db-does-not-exist", + URI: "some-uri", + Protocol: "mysql", + Overwrite: true, + }, + verifyResponse: func(t *testing.T, resp *roundtrip.Response, req createOrOverwriteDatabaseRequest, err error) { + require.True(t, trace.IsNotFound(err)) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "databases") + resp, err := pack.clt.PostJSON(context.Background(), endpoint, test.req) + + test.verifyResponse(t, resp, test.req, err) + }) + } +} + func TestUpdateDatabase_Errors(t *testing.T) { t.Parallel() @@ -7377,7 +7450,7 @@ func TestUpdateDatabase_Errors(t *testing.T) { // Create database createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases") - _, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{ + _, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createOrOverwriteDatabaseRequest{ Name: databaseName, Protocol: "mysql", URI: "someuri:3306", @@ -7480,7 +7553,7 @@ func TestUpdateDatabase_NonErrors(t *testing.T) { // Create a database. dbProtocol := "mysql" - database, err := getNewDatabaseResource(createDatabaseRequest{ + database, err := getNewDatabaseResource(createOrOverwriteDatabaseRequest{ Name: databaseName, Protocol: dbProtocol, URI: "someuri:3306", diff --git a/lib/web/databases.go b/lib/web/databases.go index 1d84f5c476937..56b3fa6ef3363 100644 --- a/lib/web/databases.go +++ b/lib/web/databases.go @@ -45,14 +45,20 @@ import ( "github.com/gravitational/teleport/lib/web/ui" ) -// createDatabaseRequest contains the necessary basic information to create a database. -// Database here is the database resource, containing information to a real database (protocol, uri) -type createDatabaseRequest struct { +// createOrOverwriteDatabaseRequest contains the necessary basic information +// to create (or overwrite) a database. +// Database here is the database resource, containing information to a real +// database (protocol, uri). +type createOrOverwriteDatabaseRequest struct { Name string `json:"name,omitempty"` Labels []ui.Label `json:"labels,omitempty"` Protocol string `json:"protocol,omitempty"` URI string `json:"uri,omitempty"` AWSRDS *awsRDS `json:"awsRds,omitempty"` + // Overwrite will replace an existing db resource + // with a new db resource. Only the name cannot + // be changed. + Overwrite bool `json:"overwrite,omitempty"` } type awsRDS struct { @@ -62,7 +68,7 @@ type awsRDS struct { VPCID string `json:"vpcId,omitempty"` } -func (r *createDatabaseRequest) checkAndSetDefaults() error { +func (r *createOrOverwriteDatabaseRequest) checkAndSetDefaults() error { if r.Name == "" { return trace.BadParameter("missing database name") } @@ -94,8 +100,8 @@ func (r *createDatabaseRequest) checkAndSetDefaults() error { } // handleDatabaseCreate creates a database's metadata. -func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { - var req *createDatabaseRequest +func (h *Handler) handleDatabaseCreateOrOverwrite(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { + var req *createOrOverwriteDatabaseRequest if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) } @@ -114,11 +120,20 @@ func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p return nil, trace.Wrap(err) } - if err := clt.CreateDatabase(r.Context(), database); err != nil { - if trace.IsAlreadyExists(err) { - return nil, trace.AlreadyExists("failed to create database (%q already exists), please use another name", req.Name) + if req.Overwrite { + if _, err := clt.GetDatabase(r.Context(), req.Name); err != nil { + return nil, trace.Wrap(err) + } + if err := clt.UpdateDatabase(r.Context(), database); err != nil { + return nil, trace.Wrap(err) + } + } else { + if err := clt.CreateDatabase(r.Context(), database); err != nil { + if trace.IsAlreadyExists(err) { + return nil, trace.AlreadyExists("failed to create database (%q already exists), please use another name", req.Name) + } + return nil, trace.Wrap(err) } - return nil, trace.Wrap(err) } accessChecker, err := sctx.GetUserAccessChecker() @@ -170,7 +185,7 @@ func (r *updateDatabaseRequest) checkAndSetDefaults() error { } // handleDatabaseUpdate updates the database -func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { +func (h *Handler) handleDatabasePartialUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { databaseName := p.ByName("database") if databaseName == "" { return nil, trace.BadParameter("a database name is required") @@ -219,7 +234,7 @@ func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p savedLabels := database.GetStaticLabels() // Make a new database to reset the check and set defaulted fields. - database, err = getNewDatabaseResource(createDatabaseRequest{ + database, err = getNewDatabaseResource(createOrOverwriteDatabaseRequest{ Name: databaseName, Protocol: database.GetProtocol(), URI: savedOrNewURI, @@ -229,7 +244,6 @@ func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p if err != nil { return nil, trace.Wrap(err) } - database.SetCA(savedOrNewCaCert) if len(req.Labels) == 0 { database.SetStaticLabels(savedLabels) @@ -397,7 +411,7 @@ func fetchDatabaseWithName(ctx context.Context, clt resourcesAPIGetter, r *http. } } -func getNewDatabaseResource(req createDatabaseRequest) (*types.DatabaseV3, error) { +func getNewDatabaseResource(req createOrOverwriteDatabaseRequest) (*types.DatabaseV3, error) { labels := make(map[string]string) for _, label := range req.Labels { labels[label.Name] = label.Value diff --git a/lib/web/databases_test.go b/lib/web/databases_test.go index 1fe3f66470b94..3a0dd78d212d0 100644 --- a/lib/web/databases_test.go +++ b/lib/web/databases_test.go @@ -47,12 +47,12 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { for _, test := range []struct { desc string - req createDatabaseRequest + req createOrOverwriteDatabaseRequest errAssert require.ErrorAssertionFunc }{ { desc: "valid general", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "name", Protocol: "protocol", URI: "uri", @@ -61,7 +61,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "valid aws rds", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "name", Protocol: "protocol", URI: "uri", @@ -76,7 +76,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "invalid missing name", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "", Protocol: "protocol", URI: "uri", @@ -88,7 +88,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "invalid missing protocol", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "name", Protocol: "", URI: "uri", @@ -100,7 +100,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "invalid missing uri", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "name", Protocol: "protocol", URI: "", @@ -112,7 +112,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "invalid missing aws rds account id", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "", Protocol: "protocol", URI: "uri", @@ -129,7 +129,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "invalid missing aws rds resource id", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "", Protocol: "protocol", URI: "uri", @@ -146,7 +146,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "invalid missing aws rds subnets", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "", Protocol: "protocol", URI: "uri", @@ -163,7 +163,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) { }, { desc: "invalid missing aws rds vpcid", - req: createDatabaseRequest{ + req: createOrOverwriteDatabaseRequest{ Name: "", Protocol: "protocol", URI: "uri", diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 888882579f8dc..91764499f3158 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -110,11 +110,33 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p return nil, trace.Wrap(err) } - databaseAgentMatcherLabels := make(types.Labels, len(req.DatabaseAgentMatcherLabels)) + databaseAgentMatcherLabels := make(types.Labels, len(req.DatabaseAgentMatcherLabels)+3) for _, label := range req.DatabaseAgentMatcherLabels { databaseAgentMatcherLabels[label.Name] = utils.Strings{label.Value} } + // DELETE in 19.0: delete only the outer if block (checking labels == 0). + // The outer block is required since older UI's will not + // send these values to the backend, but instead send custom labels (the UI + // will require at least one label before proceeding). + // Newer UI's will not send any labels, but instead send the required + // fields for default labels. + if len(req.DatabaseAgentMatcherLabels) == 0 { + if req.VPCID == "" { + return nil, trace.BadParameter("vpc ID is required") + } + if req.Region == "" { + return nil, trace.BadParameter("AWS region is required") + } + if req.AccountID == "" { + return nil, trace.BadParameter("AWS account ID is required") + } + // Add default labels. + databaseAgentMatcherLabels[types.DiscoveryLabelVPCID] = []string{req.VPCID} + databaseAgentMatcherLabels[types.DiscoveryLabelRegion] = []string{req.Region} + databaseAgentMatcherLabels[types.DiscoveryLabelAccountID] = []string{req.AccountID} + } + iamTokenName := deployserviceconfig.DefaultTeleportIAMTokenName teleportConfigString, err := deployserviceconfig.GenerateTeleportConfigString( h.PublicProxyAddr(), diff --git a/lib/web/ui/integration.go b/lib/web/ui/integration.go index fe0a3c305f21c..30543391d8d64 100644 --- a/lib/web/ui/integration.go +++ b/lib/web/ui/integration.go @@ -187,6 +187,9 @@ type AWSOIDCDeployServiceRequest struct { // Region is the AWS Region for the Service. Region string `json:"region"` + // VPCID is the VPCID where the service is going to be deployed. + VPCID string `json:"vpcId"` + // AccountID is the AWS Account ID. // Optional. sts.GetCallerIdentity is used if the value is not provided. AccountID string `json:"accountId"`