Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use default labels when deploying service into ECS and allow overwriting existing DB (web) #45180

Merged
merged 5 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,8 @@ func (h *Handler) bindDefaultEndpoints() {

// Database access handlers.
h.GET("/webapi/sites/:site/databases", h.WithClusterAuth(h.clusterDatabasesGet))
h.POST("/webapi/sites/:site/databases", h.WithClusterAuth(h.handleDatabaseCreate))
h.PUT("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.handleDatabaseUpdate))
h.POST("/webapi/sites/:site/databases", h.WithClusterAuth(h.handleDatabaseCreateOrOverwrite))
h.PUT("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.handleDatabasePartialUpdate))
h.GET("/webapi/sites/:site/databases/:database", h.WithClusterAuth(h.clusterDatabaseGet))
h.GET("/webapi/sites/:site/databases/:database/iam/policy", h.WithClusterAuth(h.handleDatabaseGetIAMPolicy))
h.GET("/webapi/scripts/databases/configure/sqlserver/:token/configure-ad.ps1", httplib.MakeHandler(h.sqlServerConfigureADScriptHandle))
Expand Down
95 changes: 84 additions & 11 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7196,7 +7196,7 @@ func TestCreateDatabase(t *testing.T) {
createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases")

// Create an initial database to table test a duplicate creation
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createOrOverwriteDatabaseRequest{
Name: "duplicatedb",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7205,13 +7205,13 @@ func TestCreateDatabase(t *testing.T) {

for _, tt := range []struct {
name string
req createDatabaseRequest
req createOrOverwriteDatabaseRequest
expectedStatus int
errAssert require.ErrorAssertionFunc
}{
{
name: "valid",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "mydatabase",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7227,7 +7227,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "valid with labels",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "dbwithlabels",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7247,7 +7247,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "empty name",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "mysql",
URI: "someuri:3306",
Expand All @@ -7259,7 +7259,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "empty protocol",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "emptyprotocol",
Protocol: "",
URI: "someuri:3306",
Expand All @@ -7271,7 +7271,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "empty uri",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "emptyuri",
Protocol: "mysql",
URI: "",
Expand All @@ -7283,7 +7283,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "missing port",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "missingport",
Protocol: "mysql",
URI: "someuri",
Expand All @@ -7295,7 +7295,7 @@ func TestCreateDatabase(t *testing.T) {
},
{
name: "duplicatedb",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "duplicatedb",
Protocol: "mysql",
URI: "someuri:3306",
Expand Down Expand Up @@ -7352,6 +7352,79 @@ func TestCreateDatabase(t *testing.T) {
}
}

func TestOverwriteDatabase(t *testing.T) {
t.Parallel()

env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "user", nil /* roles */)

initDb, err := types.NewDatabaseV3(types.Metadata{
Name: "postgres",
}, types.DatabaseSpecV3{
Protocol: "postgres",
URI: "localhost:5432",
AWS: types.AWS{
AccountID: "123456789012",
},
})
require.NoError(t, err)

err = env.server.Auth().CreateDatabase(context.Background(), initDb)
require.NoError(t, err)

tests := []struct {
name string
req createOrOverwriteDatabaseRequest
verifyResponse func(*testing.T, *roundtrip.Response, createOrOverwriteDatabaseRequest, error)
}{
{
name: "overwrite",
req: createOrOverwriteDatabaseRequest{
Name: initDb.GetName(),
Overwrite: true,
URI: "some-other-uri:3306",
Protocol: "postgres",
},
verifyResponse: func(t *testing.T, resp *roundtrip.Response, req createOrOverwriteDatabaseRequest, err error) {
require.NoError(t, err)

var gotDb ui.Database
require.NoError(t, json.Unmarshal(resp.Bytes(), &gotDb))
require.Equal(t, req.URI, gotDb.URI)
require.Equal(t, req.Protocol, gotDb.Protocol)
require.Empty(t, req.AWSRDS)
require.Equal(t, initDb.GetName(), gotDb.Name)

backendDb, err := env.server.Auth().GetDatabase(context.Background(), req.Name)
require.NoError(t, err)
require.Equal(t, ui.MakeDatabase(backendDb, nil, nil, false), gotDb)
},
},
{
name: "overwrite error: database does not exist",
req: createOrOverwriteDatabaseRequest{
Name: "this-db-does-not-exist",
URI: "some-uri",
Protocol: "mysql",
Overwrite: true,
},
verifyResponse: func(t *testing.T, resp *roundtrip.Response, req createOrOverwriteDatabaseRequest, err error) {
require.True(t, trace.IsNotFound(err))
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "databases")
resp, err := pack.clt.PostJSON(context.Background(), endpoint, test.req)

test.verifyResponse(t, resp, test.req, err)
})
}
}

func TestUpdateDatabase_Errors(t *testing.T) {
t.Parallel()

Expand All @@ -7377,7 +7450,7 @@ func TestUpdateDatabase_Errors(t *testing.T) {

// Create database
createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases")
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{
_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createOrOverwriteDatabaseRequest{
Name: databaseName,
Protocol: "mysql",
URI: "someuri:3306",
Expand Down Expand Up @@ -7480,7 +7553,7 @@ func TestUpdateDatabase_NonErrors(t *testing.T) {

// Create a database.
dbProtocol := "mysql"
database, err := getNewDatabaseResource(createDatabaseRequest{
database, err := getNewDatabaseResource(createOrOverwriteDatabaseRequest{
Name: databaseName,
Protocol: dbProtocol,
URI: "someuri:3306",
Expand Down
42 changes: 28 additions & 14 deletions lib/web/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@ import (
"github.com/gravitational/teleport/lib/web/ui"
)

// createDatabaseRequest contains the necessary basic information to create a database.
// Database here is the database resource, containing information to a real database (protocol, uri)
type createDatabaseRequest struct {
// createOrOverwriteDatabaseRequest contains the necessary basic information
// to create (or overwrite) a database.
// Database here is the database resource, containing information to a real
// database (protocol, uri).
type createOrOverwriteDatabaseRequest struct {
Name string `json:"name,omitempty"`
Labels []ui.Label `json:"labels,omitempty"`
Protocol string `json:"protocol,omitempty"`
URI string `json:"uri,omitempty"`
AWSRDS *awsRDS `json:"awsRds,omitempty"`
// Overwrite will replace an existing db resource
// with a new db resource. Only the name cannot
// be changed.
Overwrite bool `json:"overwrite,omitempty"`
}

type awsRDS struct {
Expand All @@ -62,7 +68,7 @@ type awsRDS struct {
VPCID string `json:"vpcId,omitempty"`
}

func (r *createDatabaseRequest) checkAndSetDefaults() error {
func (r *createOrOverwriteDatabaseRequest) checkAndSetDefaults() error {
if r.Name == "" {
return trace.BadParameter("missing database name")
}
Expand Down Expand Up @@ -94,8 +100,8 @@ func (r *createDatabaseRequest) checkAndSetDefaults() error {
}

// handleDatabaseCreate creates a database's metadata.
func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
var req *createDatabaseRequest
func (h *Handler) handleDatabaseCreateOrOverwrite(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
var req *createOrOverwriteDatabaseRequest
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -114,11 +120,20 @@ func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p
return nil, trace.Wrap(err)
}

if err := clt.CreateDatabase(r.Context(), database); err != nil {
if trace.IsAlreadyExists(err) {
return nil, trace.AlreadyExists("failed to create database (%q already exists), please use another name", req.Name)
if req.Overwrite {
if _, err := clt.GetDatabase(r.Context(), req.Name); err != nil {
return nil, trace.Wrap(err)
}
if err := clt.UpdateDatabase(r.Context(), database); err != nil {
return nil, trace.Wrap(err)
}
} else {
if err := clt.CreateDatabase(r.Context(), database); err != nil {
if trace.IsAlreadyExists(err) {
return nil, trace.AlreadyExists("failed to create database (%q already exists), please use another name", req.Name)
}
return nil, trace.Wrap(err)
}
return nil, trace.Wrap(err)
}

accessChecker, err := sctx.GetUserAccessChecker()
Expand Down Expand Up @@ -170,7 +185,7 @@ func (r *updateDatabaseRequest) checkAndSetDefaults() error {
}

// handleDatabaseUpdate updates the database
func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
func (h *Handler) handleDatabasePartialUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) {
databaseName := p.ByName("database")
if databaseName == "" {
return nil, trace.BadParameter("a database name is required")
Expand Down Expand Up @@ -219,7 +234,7 @@ func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p
savedLabels := database.GetStaticLabels()

// Make a new database to reset the check and set defaulted fields.
database, err = getNewDatabaseResource(createDatabaseRequest{
database, err = getNewDatabaseResource(createOrOverwriteDatabaseRequest{
Name: databaseName,
Protocol: database.GetProtocol(),
URI: savedOrNewURI,
Expand All @@ -229,7 +244,6 @@ func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p
if err != nil {
return nil, trace.Wrap(err)
}

database.SetCA(savedOrNewCaCert)
if len(req.Labels) == 0 {
database.SetStaticLabels(savedLabels)
Expand Down Expand Up @@ -397,7 +411,7 @@ func fetchDatabaseWithName(ctx context.Context, clt resourcesAPIGetter, r *http.
}
}

func getNewDatabaseResource(req createDatabaseRequest) (*types.DatabaseV3, error) {
func getNewDatabaseResource(req createOrOverwriteDatabaseRequest) (*types.DatabaseV3, error) {
labels := make(map[string]string)
for _, label := range req.Labels {
labels[label.Name] = label.Value
Expand Down
20 changes: 10 additions & 10 deletions lib/web/databases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {

for _, test := range []struct {
desc string
req createDatabaseRequest
req createOrOverwriteDatabaseRequest
errAssert require.ErrorAssertionFunc
}{
{
desc: "valid general",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "protocol",
URI: "uri",
Expand All @@ -61,7 +61,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "valid aws rds",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "protocol",
URI: "uri",
Expand All @@ -76,7 +76,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing name",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -88,7 +88,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing protocol",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "",
URI: "uri",
Expand All @@ -100,7 +100,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing uri",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "name",
Protocol: "protocol",
URI: "",
Expand All @@ -112,7 +112,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds account id",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -129,7 +129,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds resource id",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -146,7 +146,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds subnets",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand All @@ -163,7 +163,7 @@ func TestCreateDatabaseRequestParameters(t *testing.T) {
},
{
desc: "invalid missing aws rds vpcid",
req: createDatabaseRequest{
req: createOrOverwriteDatabaseRequest{
Name: "",
Protocol: "protocol",
URI: "uri",
Expand Down
Loading
Loading