diff --git a/lib/auth/join.go b/lib/auth/join.go index 5aea611bafdea..6cb33b176a1d9 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -207,14 +207,12 @@ func (a *Server) handleJoinFailure( // will be checked. func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (certs *proto.Certs, err error) { attrs := &workloadidentityv1pb.JoinAttrs{} - // rawJoinAttrs typically holds the raw metadata sourced from a join. - // E.g the claims from a JWT token. This is used for auditing purposes. - var rawJoinAttrs any + var rawClaims any var provisionToken types.ProvisionToken defer func() { // Emit a log message and audit event on join failure. if err != nil { - a.handleJoinFailure(err, provisionToken, rawJoinAttrs, req) + a.handleJoinFailure(err, provisionToken, rawClaims, req) } }() @@ -236,7 +234,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGitHub: claims, err := a.checkGitHubJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims attrs.Github = claims.JoinAttrs() } if err != nil { @@ -245,7 +243,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGitLab: claims, err := a.checkGitLabJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims attrs.Gitlab = claims.JoinAttrs() } if err != nil { @@ -254,7 +252,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodCircleCI: claims, err := a.checkCircleCIJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims attrs.Circleci = claims.JoinAttrs() } if err != nil { @@ -263,7 +261,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodKubernetes: claims, err := a.checkKubernetesJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims } if err != nil { return nil, trace.Wrap(err) @@ -271,7 +269,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGCP: claims, err := a.checkGCPJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims } if err != nil { return nil, trace.Wrap(err) @@ -279,7 +277,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodSpacelift: claims, err := a.checkSpaceliftJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims + attrs.Spacelift = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -287,7 +286,8 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodTerraformCloud: claims, err := a.checkTerraformCloudJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims + attrs.TerraformCloud = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -295,7 +295,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodBitbucket: claims, err := a.checkBitbucketJoinRequest(ctx, req) if claims != nil { - rawJoinAttrs = claims + rawClaims = claims attrs.Bitbucket = claims.JoinAttrs() } if err != nil { @@ -323,12 +323,12 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin ctx, provisionToken, req, - rawJoinAttrs, + rawClaims, attrs, ) return certs, trace.Wrap(err) } - certs, err = a.generateCerts(ctx, provisionToken, req, rawJoinAttrs) + certs, err = a.generateCerts(ctx, provisionToken, req, rawClaims) return certs, trace.Wrap(err) } diff --git a/lib/spacelift/spacelift.go b/lib/spacelift/spacelift.go index ddaba2f11cfd2..289e074fcb3b0 100644 --- a/lib/spacelift/spacelift.go +++ b/lib/spacelift/spacelift.go @@ -21,6 +21,8 @@ package spacelift import ( "github.com/gravitational/trace" "github.com/mitchellh/mapstructure" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // IDTokenClaims @@ -49,6 +51,21 @@ type IDTokenClaims struct { Scope string `json:"scope"` } +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsSpacelift { + return &workloadidentityv1pb.JoinAttrsSpacelift{ + Sub: c.Sub, + SpaceId: c.SpaceID, + CallerType: c.CallerType, + CallerId: c.CallerID, + RunType: c.RunType, + RunId: c.RunID, + Scope: c.Scope, + } +} + // JoinAuditAttributes returns a series of attributes that can be inserted into // audit events related to a specific join. func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { diff --git a/lib/terraformcloud/terraform.go b/lib/terraformcloud/terraform.go index ded2340c2e5d1..c9db802130ae2 100644 --- a/lib/terraformcloud/terraform.go +++ b/lib/terraformcloud/terraform.go @@ -19,8 +19,7 @@ package terraformcloud import ( - "github.com/gravitational/trace" - "github.com/mitchellh/mapstructure" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // IDTokenClaims @@ -52,20 +51,17 @@ type IDTokenClaims struct { RunPhase string `json:"terraform_run_phase"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *IDTokenClaims) JoinAuditAttributes() (map[string]interface{}, error) { - res := map[string]interface{}{} - d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: &res, - }) - if err != nil { - return nil, trace.Wrap(err) +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *IDTokenClaims) JoinAttrs() *workloadidentityv1pb.JoinAttrsTerraformCloud { + return &workloadidentityv1pb.JoinAttrsTerraformCloud{ + Sub: c.Sub, + OrganizationName: c.OrganizationName, + ProjectName: c.ProjectName, + WorkspaceName: c.WorkspaceName, + FullWorkspace: c.FullWorkspace, + RunId: c.RunID, + RunPhase: c.RunPhase, } - - if err := d.Decode(c); err != nil { - return nil, trace.Wrap(err) - } - return res, nil }