Skip to content

Commit

Permalink
Allow passing an optional ARN when health checking an AWSOIDC Integra…
Browse files Browse the repository at this point in the history
…tion (#48031)
  • Loading branch information
marcoandredinis authored Oct 29, 2024
1 parent 10d6e87 commit 366b614
Show file tree
Hide file tree
Showing 13 changed files with 196 additions and 147 deletions.
199 changes: 106 additions & 93 deletions api/gen/proto/go/teleport/integration/v1/awsoidc_service.pb.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion api/proto/teleport/integration/v1/awsoidc_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,13 @@ message ListEKSClustersResponse {
// PingRequest is a request for doing an health check against the configured integration.
message PingRequest {
// Integration is the AWS OIDC Integration name.
// Required.
// Required if ARN is empty.
string integration = 1;

// The AWS Role ARN to be used when generating the token.
// This is used to test another ARN before saving the Integration.
// Required if integration is empty.
string role_arn = 2;
}

// PingResponse contains the response for the Ping operation.
Expand Down
52 changes: 35 additions & 17 deletions lib/auth/integration/integrationv1/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,35 +157,48 @@ func NewAWSOIDCService(cfg *AWSOIDCServiceConfig) (*AWSOIDCService, error) {

var _ integrationpb.AWSOIDCServiceServer = (*AWSOIDCService)(nil)

func (s *AWSOIDCService) awsClientReq(ctx context.Context, integrationName, region string) (*awsoidc.AWSClientRequest, error) {
func (s *AWSOIDCService) roleARNForIntegration(ctx context.Context, integrationName string) (string, error) {
integration, err := s.integrationService.GetIntegration(ctx, &integrationpb.GetIntegrationRequest{
Name: integrationName,
})
if err != nil {
return nil, trace.Wrap(err)
return "", trace.Wrap(err)
}

if integration.GetSubKind() != types.IntegrationSubKindAWSOIDC {
return nil, trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
return "", trace.BadParameter("integration subkind (%s) mismatch", integration.GetSubKind())
}

if integration.GetAWSOIDCIntegrationSpec() == nil {
return nil, trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
return "", trace.BadParameter("missing spec fields for %q (%q) integration", integration.GetName(), integration.GetSubKind())
}

return integration.GetAWSOIDCIntegrationSpec().RoleARN, nil
}

func (s *AWSOIDCService) awsClientReqWithARN(ctx context.Context, integrationName, region, arn string) (*awsoidc.AWSClientRequest, error) {
token, err := s.integrationService.generateAWSOIDCTokenWithoutAuthZ(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}

return &awsoidc.AWSClientRequest{
IntegrationName: integrationName,
Token: token.Token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: region,
Token: token.Token,
RoleARN: arn,
Region: region,
}, nil
}

func (s *AWSOIDCService) awsClientReq(ctx context.Context, integrationName, region string) (*awsoidc.AWSClientRequest, error) {
roleARN, err := s.roleARNForIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}

return s.awsClientReqWithARN(ctx, integrationName, region, roleARN)

}

// ListEICE returns a paginated list of EC2 Instance Connect Endpoints.
func (s *AWSOIDCService) ListEICE(ctx context.Context, req *integrationpb.ListEICERequest) (*integrationpb.ListEICEResponse, error) {
authCtx, err := s.authorizer.Authorize(ctx)
Expand Down Expand Up @@ -793,15 +806,20 @@ func (s *AWSOIDCService) Ping(ctx context.Context, req *integrationpb.PingReques
return nil, trace.Wrap(err)
}

if req.Integration == "" {
return nil, trace.BadParameter("integration is required")
}

// Instead of asking the user for a region (or storing a default region), we use the sentinel value for the global region.
// This improves the UX, because it is one less input we require from the user.
awsClientReq, err := s.awsClientReq(ctx, req.Integration, awsutils.AWSGlobalRegion)
if err != nil {
return nil, trace.Wrap(err)
var awsClientReq *awsoidc.AWSClientRequest
switch {
case req.GetRoleArn() != "":
awsClientReq, err = s.awsClientReqWithARN(ctx, req.Integration, awsutils.AWSGlobalRegion, req.GetRoleArn())
if err != nil {
return nil, trace.Wrap(err)
}
case req.GetIntegration() != "":
awsClientReq, err = s.awsClientReq(ctx, req.GetIntegration(), awsutils.AWSGlobalRegion)
if err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.BadParameter("one of arn and integration is required")
}

awsClient, err := awsoidc.NewPingClient(ctx, awsClientReq)
Expand Down
10 changes: 10 additions & 0 deletions lib/auth/integration/integrationv1/awsoidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,16 @@ func TestRBAC(t *testing.T) {
return err
},
},
{
name: "Ping with arn",
fn: func() error {
_, err := awsoidService.Ping(userCtx, &integrationv1.PingRequest{
Integration: integrationName,
RoleArn: "some-arn",
})
return err
},
},
} {
t.Run(tt.name, func(t *testing.T) {
err := tt.fn()
Expand Down
7 changes: 0 additions & 7 deletions lib/integrations/awsoidc/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ import (

// AWSClientRequest contains the required fields to set up an AWS service client.
type AWSClientRequest struct {
// IntegrationName is the integration name that is going to issue an API Call.
IntegrationName string

// Token is the token used to issue the API Call.
Token string

Expand All @@ -55,10 +52,6 @@ type AWSClientRequest struct {

// CheckAndSetDefaults checks if the required fields are present.
func (req *AWSClientRequest) CheckAndSetDefaults() error {
if req.IntegrationName == "" {
return trace.BadParameter("integration name is required")
}

if req.Token == "" {
return trace.BadParameter("token is required")
}
Expand Down
21 changes: 9 additions & 12 deletions lib/integrations/awsoidc/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,27 @@ import (
func TestCheckAndSetDefaults(t *testing.T) {
t.Run("invalid regions must return an error", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "?",
Token: "token",
RoleARN: "some-arn",
Region: "?",
}).CheckAndSetDefaults()

require.True(t, trace.IsBadParameter(err))
})
t.Run("valid region", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "us-east-1",
Token: "token",
RoleARN: "some-arn",
Region: "us-east-1",
}).CheckAndSetDefaults()
require.NoError(t, err)
})

t.Run("empty region", func(t *testing.T) {
err := (&AWSClientRequest{
IntegrationName: "my-integration",
Token: "token",
RoleARN: "some-arn",
Region: "",
Token: "token",
RoleARN: "some-arn",
Region: "",
}).CheckAndSetDefaults()
require.NoError(t, err)
})
Expand Down
9 changes: 4 additions & 5 deletions lib/integrations/awsoidc/deployservice_vcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ func TestDeployDBService(t *testing.T) {
return &AWSClientRequest{
// To record new fixtures you will need a valid token.
// You can get one by getting the generated token in a real cluster.
Token: awsOIDCToken,
RoleARN: awsOIDCRoleARN,
Region: awsRegion,
IntegrationName: integrationName,
httpClient: httpClient,
Token: awsOIDCToken,
RoleARN: awsOIDCRoleARN,
Region: awsRegion,
httpClient: httpClient,
}
}

Expand Down
2 changes: 2 additions & 0 deletions lib/integrations/awsoidc/token_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type KeyStoreManager interface {
// GenerateAWSOIDCTokenRequest contains the required elements to generate an AWS OIDC Token (JWT).
type GenerateAWSOIDCTokenRequest struct {
// Integration is the AWS OIDC Integration name.
// This field is only used to obtain custom Issuers (those stored at S3 buckets).
// If empty, the default issuer for the cluster (its public endpoint URL) will be used.
Integration string
// Username is the JWT Username (on behalf of claim)
Username string
Expand Down
7 changes: 3 additions & 4 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,9 @@ func (s *localSite) setupTunnelForOpenSSHEICENode(ctx context.Context, targetSer
}

openTunnelClt, err := awsoidc.NewOpenTunnelEC2Client(ctx, &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
})
if err != nil {
return nil, trace.BadParameter("failed to create the ec2 open tunnel client: %v", err)
Expand Down
7 changes: 3 additions & 4 deletions lib/service/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,9 @@ func (updater *AWSOIDCDeployServiceUpdater) updateAWSOIDCDeployService(ctx conte
}

req := &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsRegion,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsRegion,
}

// The deploy service client is initialized using AWS OIDC integration.
Expand Down
7 changes: 3 additions & 4 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,10 +689,9 @@ func (s *Server) sendSSHPublicKeyToTarget(ctx context.Context) (ssh.Signer, erro
}

sendSSHClient, err := awsoidc.NewEICESendSSHPublicKeyClient(ctx, &awsoidc.AWSClientRequest{
IntegrationName: integration.GetName(),
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: awsInfo.Region,
})
if err != nil {
return nil, trace.BadParameter("failed to create an aws client to send ssh public key: %v", err)
Expand Down
7 changes: 7 additions & 0 deletions lib/web/integrations_awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,7 @@ func getServiceURLs(dbServices []types.DatabaseService, accountID, region, telep
}

// awsOIDCPing performs an health check for the integration.
// If ARN is present in the request body, that's the ARN that will be used instead of using the one stored in the integration.
// Returns meta information: account id and assumed the ARN for the IAM Role.
func (h *Handler) awsOIDCPing(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (any, error) {
ctx := r.Context()
Expand All @@ -1440,13 +1441,19 @@ func (h *Handler) awsOIDCPing(w http.ResponseWriter, r *http.Request, p httprout
return nil, trace.BadParameter("an integration name is required")
}

var req ui.AWSOIDCPingRequest
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}

clt, err := sctx.GetUserClient(ctx, site)
if err != nil {
return nil, trace.Wrap(err)
}

pingResp, err := clt.IntegrationAWSOIDCClient().Ping(ctx, &integrationv1.PingRequest{
Integration: integrationName,
RoleArn: req.RoleARN,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
8 changes: 8 additions & 0 deletions lib/web/ui/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,11 @@ type AWSOIDCPingResponse struct {
// UserID is the unique identifier of the calling entity.
UserID string `json:"userId"`
}

// AWSOIDCPingRequest contains ping request fields.
type AWSOIDCPingRequest struct {
// RoleARN is optional, and used for cases such as
// pinging to check validity before upserting an
// AWS OIDC integration.
RoleARN string `json:"roleArn,omitempty"`
}

0 comments on commit 366b614

Please sign in to comment.