diff --git a/lib/auth/bot.go b/lib/auth/bot.go index 01bb8d6bd2b59..840fe80195444 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -313,7 +313,7 @@ func (a *Server) updateBotInstance( if templateAuthRecord != nil { authRecord.JoinToken = templateAuthRecord.JoinToken authRecord.JoinMethod = templateAuthRecord.JoinMethod - authRecord.Metadata = templateAuthRecord.Metadata + authRecord.JoinAttrs = templateAuthRecord.JoinAttrs } // An empty bot instance most likely means a bot is rejoining after an diff --git a/lib/auth/join.go b/lib/auth/join.go index 0c6f0ecb1df29..afaa9a47a6561 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -118,7 +118,7 @@ func setRemoteAddrFromContext(ctx context.Context, req *types.RegisterUsingToken func (a *Server) handleJoinFailure( origErr error, pt types.ProvisionToken, - attributes any, + rawJoinAttrs any, req *types.RegisterUsingTokenRequest, ) { fields := logrus.Fields{} @@ -129,10 +129,13 @@ func (a *Server) handleJoinFailure( fields["remote_addr"] = req.RemoteAddr } - // Fetch and encode attributes if they are available. - attributesStruct, err := untypedAttrsToStruct(attributes) + // Fetch and encode rawJoinAttrs if they are available. + attributesStruct, err := rawJoinAttrsToStruct(rawJoinAttrs) if err != nil { - log.WithError(err).Warn("Unable to encode join attributes for audit event.") + log.WithError(err).Warn("Unable to encode join rawJoinAttrs for audit event.") + } + if attributesStruct != nil { + fields["attributes"] = attributesStruct } // Add log fields from token if available. @@ -204,15 +207,14 @@ func (a *Server) handleJoinFailure( // will be checked. func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (certs *proto.Certs, err error) { attrs := &workloadidentityv1pb.JoinAttrs{} - // untypedAttrs holds the unstructured join attributes specific to that - // join method for the purposes of including in the audit logs. - // Realistically, this can hold anything that can be JSON marshaled. - var untypedAttrs any + // rawJoinAttrs typically holds the raw metadata sourced from a join. + // E.g the claims from a JWT token. This is used for auditing purposes. + var rawJoinAttrs any var provisionToken types.ProvisionToken defer func() { // Emit a log message and audit event on join failure. if err != nil { - a.handleJoinFailure(err, provisionToken, untypedAttrs, req) + a.handleJoinFailure(err, provisionToken, rawJoinAttrs, req) } }() @@ -234,7 +236,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGitHub: claims, err := a.checkGitHubJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims attrs.Github = claims.JoinAttrs() } if err != nil { @@ -243,7 +245,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGitLab: claims, err := a.checkGitLabJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims attrs.Gitlab = claims.JoinAttrs() } if err != nil { @@ -252,7 +254,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodCircleCI: claims, err := a.checkCircleCIJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims } if err != nil { return nil, trace.Wrap(err) @@ -260,7 +262,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodKubernetes: claims, err := a.checkKubernetesJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims } if err != nil { return nil, trace.Wrap(err) @@ -268,7 +270,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodGCP: claims, err := a.checkGCPJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims } if err != nil { return nil, trace.Wrap(err) @@ -276,7 +278,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodSpacelift: claims, err := a.checkSpaceliftJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims } if err != nil { return nil, trace.Wrap(err) @@ -284,7 +286,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodTerraformCloud: claims, err := a.checkTerraformCloudJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims } if err != nil { return nil, trace.Wrap(err) @@ -292,7 +294,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin case types.JoinMethodBitbucket: claims, err := a.checkBitbucketJoinRequest(ctx, req) if claims != nil { - untypedAttrs = claims + rawJoinAttrs = claims } if err != nil { return nil, trace.Wrap(err) @@ -319,12 +321,12 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin ctx, provisionToken, req, - untypedAttrs, + rawJoinAttrs, attrs, ) return certs, trace.Wrap(err) } - certs, err = a.generateCerts(ctx, provisionToken, req, untypedAttrs) + certs, err = a.generateCerts(ctx, provisionToken, req, rawJoinAttrs) return certs, trace.Wrap(err) } @@ -332,7 +334,7 @@ func (a *Server) generateCertsBot( ctx context.Context, provisionToken types.ProvisionToken, req *types.RegisterUsingTokenRequest, - untypedAttrs any, + rawJoinAttrs any, attrs *workloadidentityv1pb.JoinAttrs, ) (*proto.Certs, error) { // bots use this endpoint but get a user cert @@ -382,7 +384,7 @@ func (a *Server) generateCertsBot( }, } var err error - joinEvent.Attributes, err = untypedAttrsToStruct(untypedAttrs) + joinEvent.Attributes, err = rawJoinAttrsToStruct(rawJoinAttrs) if err != nil { log.WithError(err).Warn("Unable to encode join attributes for audit event.") } @@ -511,7 +513,7 @@ func (a *Server) generateCerts( RemoteAddr: req.RemoteAddr, }, } - joinEvent.Attributes, err = untypedAttrsToStruct(untypedAttrs) + joinEvent.Attributes, err = rawJoinAttrsToStruct(untypedAttrs) if err != nil { log.WithError(err).Warn("Unable to encode join attributes for audit event.") } @@ -521,7 +523,7 @@ func (a *Server) generateCerts( return certs, nil } -func untypedAttrsToStruct(in any) (*apievents.Struct, error) { +func rawJoinAttrsToStruct(in any) (*apievents.Struct, error) { if in == nil { return nil, nil }