Skip to content

Commit

Permalink
Use default labels when deploying service into ECS and allow overwrit…
Browse files Browse the repository at this point in the history
…ing existing DB (web) (#45180)

* Allow overwriting an existing database

* Default to pre-defined labels when deploying service into ecs

* Address CRs

* Reuse create database endpoint to support overwrite

Revert database update logic

* Add checks for required fields for default labels
  • Loading branch information
kimlisa authored Aug 21, 2024
1 parent c6f1396 commit 1472b99
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 38 deletions.
4 changes: 2 additions & 2 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,8 @@ func (h *Handler) bindDefaultEndpoints() {

// Database access handlers.
h.GET("/webapi/sites/:site/databases", h.WithClusterAuth(h.clusterDatabasesGet))
h.POST("/webapi/sites/:site/databases", h.WithClusterAuth(h.handleDatabaseCreate))
h.PUT("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.handleDatabaseUpdate))
h.POST("/webapi/sites/:site/databases", h.WithClusterAuth(h.handleDatabaseCreateOrOverwrite))
h.PUT("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.handleDatabasePartialUpdate))
h.GET("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.clusterDatabaseGet))
h.GET("/webapi/sites/:site/databases/:database/iam/policy", h.WithClusterAuth(h.handleDatabaseGetIAMPolicy))
h.GET("/webapi/scripts/databases/configure/sqlserver/:token/configure-ad.ps1", httplib.MakeHandler(h.sqlServerConfigureADScriptHandle))
Expand Down
95 changes: 84 additions & 11 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7196,7 +7196,7 @@ func TestCreateDatabase(t *testing.T) {
createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases")

// Create an initial database to table test a duplicate creation
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createOrOverwriteDatabaseRequest{
Name: "duplicatedb",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7205,13 +7205,13 @@ func TestCreateDatabase(t *testing.T) {

for _, tt := range []struct {
name string
req createDatabaseRequest
req createOrOverwriteDatabaseRequest
expectedStatus int
errAssert require.ErrorAssertionFunc
}{
{
name: "valid",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "mydatabase",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7227,7 +7227,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "valid with labels",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "dbwithlabels",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7247,7 +7247,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "empty name",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7259,7 +7259,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "empty protocol",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "emptyprotocol",
Protocol: "",
URI: "someuri:3306",
Expand All @@ -7271,7 +7271,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "empty uri",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "emptyuri",
Protocol: "mysql",
URI: "",
Expand All @@ -7283,7 +7283,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "missing port",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "missingport",
Protocol: "mysql",
URI: "someuri",
Expand All @@ -7295,7 +7295,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "duplicatedb",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "duplicatedb",
Protocol: "mysql",
URI: "someuri:3306",
Expand Down Expand Up @@ -7352,6 +7352,79 @@ func TestCreateDatabase(t *testing.T) {
}
}

func TestOverwriteDatabase(t *testing.T) {
t.Parallel()

env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "user", nil /* roles */)

initDb, err := types.NewDatabaseV3(types.Metadata{
Name: "postgres",
}, types.DatabaseSpecV3{
Protocol: "postgres",
URI: "localhost:5432",
AWS: types.AWS{
AccountID: "123456789012",
},
})
require.NoError(t, err)

err = env.server.Auth().CreateDatabase(context.Background(), initDb)
require.NoError(t, err)

tests := []struct {
name string
req createOrOverwriteDatabaseRequest
verifyResponse func(*testing.T, *roundtrip.Response, createOrOverwriteDatabaseRequest, error)
}{
{
name: "overwrite",
req: createOrOverwriteDatabaseRequest{
Name: initDb.GetName(),
Overwrite: true,
URI: "some-other-uri:3306",
Protocol: "postgres",
},
verifyResponse: func(t *testing.T, resp *roundtrip.Response, req createOrOverwriteDatabaseRequest, err error) {
require.NoError(t, err)

var gotDb ui.Database
require.NoError(t, json.Unmarshal(resp.Bytes(), &gotDb))
require.Equal(t, req.URI, gotDb.URI)
require.Equal(t, req.Protocol, gotDb.Protocol)
require.Empty(t, req.AWSRDS)
require.Equal(t, initDb.GetName(), gotDb.Name)

backendDb, err := env.server.Auth().GetDatabase(context.Background(), req.Name)
require.NoError(t, err)
require.Equal(t, ui.MakeDatabase(backendDb, nil, nil, false), gotDb)
},
},
{
name: "overwrite error: database does not exist",
req: createOrOverwriteDatabaseRequest{
Name: "this-db-does-not-exist",
URI: "some-uri",
Protocol: "mysql",
Overwrite: true,
},
verifyResponse: func(t *testing.T, resp *roundtrip.Response, req createOrOverwriteDatabaseRequest, err error) {
require.True(t, trace.IsNotFound(err))
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "databases")
resp, err := pack.clt.PostJSON(context.Background(), endpoint, test.req)

test.verifyResponse(t, resp, test.req, err)
})
}
}

func TestUpdateDatabase_Errors(t *testing.T) {
t.Parallel()

Expand All @@ -7377,7 +7450,7 @@ func TestUpdateDatabase_Errors(t *testing.T) {

// Create database
createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases")
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createOrOverwriteDatabaseRequest{
Name: databaseName,
Protocol: "mysql",
URI: "someuri:3306",
Expand Down Expand Up @@ -7480,7 +7553,7 @@ func TestUpdateDatabase_NonErrors(t *testing.T) {

// Create a database.
dbProtocol := "mysql"
database, err := getNewDatabaseResource(createDatabaseRequest{
database, err := getNewDatabaseResource(createOrOverwriteDatabaseRequest{
Name: databaseName,
Protocol: dbProtocol,
URI: "someuri:3306",
Expand Down
42 changes: 28 additions & 14 deletions lib/web/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@ import (
"github.com/gravitational/teleport/lib/web/ui"
)

// createDatabaseRequest contains the necessary basic information to create a database.
// Database here is the database resource, containing information to a real database (protocol, uri)
type createDatabaseRequest struct {
// createOrOverwriteDatabaseRequest contains the necessary basic information
// to create (or overwrite) a database.
// Database here is the database resource, containing information to a real
// database (protocol, uri).
type createOrOverwriteDatabaseRequest struct {
Name string `json:"name,omitempty"`
Labels []ui.Label `json:"labels,omitempty"`
Protocol string `json:"protocol,omitempty"`
URI string `json:"uri,omitempty"`
AWSRDS *awsRDS `json:"awsRds,omitempty"`
// Overwrite will replace an existing db resource
// with a new db resource. Only the name cannot
// be changed.
Overwrite bool `json:"overwrite,omitempty"`
}

type awsRDS struct {
Expand All @@ -62,7 +68,7 @@ type awsRDS struct {
VPCID string `json:"vpcId,omitempty"`
}

func (r *createDatabaseRequest) checkAndSetDefaults() error {
func (r *createOrOverwriteDatabaseRequest) checkAndSetDefaults() error {
if r.Name == "" {
return trace.BadParameter("missing database name")
}
Expand Down Expand Up @@ -94,8 +100,8 @@ func (r *createDatabaseRequest) checkAndSetDefaults() error {
}

// handleDatabaseCreate creates a database's metadata.
func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
var req *createDatabaseRequest
func (h *Handler) handleDatabaseCreateOrOverwrite(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
var req *createOrOverwriteDatabaseRequest
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -114,11 +120,20 @@ func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p
return nil, trace.Wrap(err)
}

if err := clt.CreateDatabase(r.Context(), database); err != nil {
if trace.IsAlreadyExists(err) {
return nil, trace.AlreadyExists("failed to create database (%q already exists), please use another name", req.Name)
if req.Overwrite {
if _, err := clt.GetDatabase(r.Context(), req.Name); err != nil {
return nil, trace.Wrap(err)
}
if err := clt.UpdateDatabase(r.Context(), database); err != nil {
return nil, trace.Wrap(err)
}
} else {
if err := clt.CreateDatabase(r.Context(), database); err != nil {
if trace.IsAlreadyExists(err) {
return nil, trace.AlreadyExists("failed to create database (%q already exists), please use another name", req.Name)
}
return nil, trace.Wrap(err)
}
return nil, trace.Wrap(err)
}

accessChecker, err := sctx.GetUserAccessChecker()
Expand Down Expand Up @@ -170,7 +185,7 @@ func (r *updateDatabaseRequest) checkAndSetDefaults() error {
}

// handleDatabaseUpdate updates the database
func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
func (h *Handler) handleDatabasePartialUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
databaseName := p.ByName("database")
if databaseName == "" {
return nil, trace.BadParameter("a database name is required")
Expand Down Expand Up @@ -219,7 +234,7 @@ func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p
savedLabels := database.GetStaticLabels()

// Make a new database to reset the check and set defaulted fields.
database, err = getNewDatabaseResource(createDatabaseRequest{
database, err = getNewDatabaseResource(createOrOverwriteDatabaseRequest{
Name: databaseName,
Protocol: database.GetProtocol(),
URI: savedOrNewURI,
Expand All @@ -229,7 +244,6 @@ func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p
if err != nil {
return nil, trace.Wrap(err)
}

database.SetCA(savedOrNewCaCert)
if len(req.Labels) == 0 {
database.SetStaticLabels(savedLabels)
Expand Down Expand Up @@ -397,7 +411,7 @@ func fetchDatabaseWithName(ctx context.Context, clt resourcesAPIGetter, r *http.
}
}

func getNewDatabaseResource(req createDatabaseRequest) (*types.DatabaseV3, error) {
func getNewDatabaseResource(req createOrOverwriteDatabaseRequest) (*types.DatabaseV3, error) {
labels := make(map[string]string)
for _, label := range req.Labels {
labels[label.Name] = label.Value
Expand Down
20 changes: 10 additions & 10 deletions lib/web/databases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {

for _, test := range []struct {
desc string
req createDatabaseRequest
req createOrOverwriteDatabaseRequest
errAssert require.ErrorAssertionFunc
}{
{
desc: "valid general",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "protocol",
URI: "uri",
Expand All @@ -61,7 +61,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "valid aws rds",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "protocol",
URI: "uri",
Expand All @@ -76,7 +76,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing name",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -88,7 +88,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing protocol",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "",
URI: "uri",
Expand All @@ -100,7 +100,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing uri",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "protocol",
URI: "",
Expand All @@ -112,7 +112,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds account id",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -129,7 +129,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds resource id",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -146,7 +146,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds subnets",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -163,7 +163,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds vpcid",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand Down
Loading

0 comments on commit 1472b99

Please sign in to comment.