From 8b4fb0b75f6b93078d19e4900f6b317f8751b70c Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 29 Nov 2024 12:14:11 +0000 Subject: [PATCH] Fix nullyness --- lib/auth/join.go | 13 ++--------- lib/auth/join_azure.go | 1 + lib/auth/join_iam.go | 50 ++++++++++++++++++++++++++++++++---------- lib/auth/join_tpm.go | 17 ++++++++++---- lib/tpm/validate.go | 21 +++++++++++------- 5 files changed, 67 insertions(+), 35 deletions(-) diff --git a/lib/auth/join.go b/lib/auth/join.go index ad85d189abe02..0c6f0ecb1df29 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -31,9 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" "google.golang.org/grpc/peer" - "google.golang.org/protobuf/encoding/protojson" - googleproto "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport/api/client/proto" @@ -100,12 +97,6 @@ func (a *Server) checkTokenJoinRequestCommon(ctx context.Context, req *types.Reg return provisionToken, nil } -type joinAttributeSourcer interface { - // JoinAuditAttributes returns a series of attributes that can be inserted into - // audit events related to a specific join. - JoinAuditAttributes() (map[string]interface{}, error) -} - func setRemoteAddrFromContext(ctx context.Context, req *types.RegisterUsingTokenRequest) error { var addr string if clientIP, err := authz.ClientSrcAddrFromContext(ctx); err == nil { @@ -244,7 +235,7 @@ func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsin claims, err := a.checkGitHubJoinRequest(ctx, req) if claims != nil { untypedAttrs = claims - attrs.Gitlab = claims.JoinAttrs() + attrs.Github = claims.JoinAttrs() } if err != nil { return nil, trace.Wrap(err) @@ -401,7 +392,7 @@ func (a *Server) generateCertsBot( if attrs == nil { attrs = &workloadidentityv1pb.JoinAttrs{} } - attrs.Meta = &workloadidentityv1pb.MetaJoinAttrs{ + attrs.Meta = &workloadidentityv1pb.JoinAttrsMeta{ JoinMethod: string(joinMethod), } if joinMethod != types.JoinMethodToken { diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index 428ac15c90679..467fc697692e6 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -409,6 +409,7 @@ func (a *Server) RegisterUsingAzureMethodWithOpts( provisionToken, req.RegisterUsingTokenRequest, nil, + nil, ) return certs, trace.Wrap(err) } diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index ba2209105c7c0..ee8c97857ee1f 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/join/iam" "github.com/gravitational/teleport/lib/utils" @@ -172,6 +173,18 @@ type awsIdentity struct { Arn string `json:"Arn"` } +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *awsIdentity) JoinAttrs() *workloadidentityv1pb.JoinAttrsAWSIAM { + attrs := &workloadidentityv1pb.JoinAttrsAWSIAM{ + Account: c.Account, + Arn: c.Arn, + } + + return attrs +} + // getCallerIdentityReponse is used for JSON parsing type getCallerIdentityResponse struct { GetCallerIdentityResult awsIdentity `json:"GetCallerIdentityResult"` @@ -260,41 +273,41 @@ func checkIAMAllowRules(identity *awsIdentity, token string, allowRules []*types // checkIAMRequest checks if the given request satisfies the token rules and // included the required challenge. -func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *proto.RegisterUsingIAMMethodRequest, cfg *iamRegisterConfig) error { +func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *proto.RegisterUsingIAMMethodRequest, cfg *iamRegisterConfig) (*awsIdentity, error) { tokenName := req.RegisterUsingTokenRequest.Token provisionToken, err := a.GetToken(ctx, tokenName) if err != nil { - return trace.Wrap(err, "getting token") + return nil, trace.Wrap(err, "getting token") } if provisionToken.GetJoinMethod() != types.JoinMethodIAM { - return trace.AccessDenied("this token does not support the IAM join method") + return nil, trace.AccessDenied("this token does not support the IAM join method") } // parse the incoming http request to the sts:GetCallerIdentity endpoint identityRequest, err := parseSTSRequest(req.StsIdentityRequest) if err != nil { - return trace.Wrap(err, "parsing STS request") + return nil, trace.Wrap(err, "parsing STS request") } // validate that the host, method, and headers are correct and the expected // challenge is included in the signed portion of the request if err := validateSTSIdentityRequest(identityRequest, challenge, cfg); err != nil { - return trace.Wrap(err, "validating STS request") + return nil, trace.Wrap(err, "validating STS request") } // send the signed request to the public AWS API and get the node identity // from the response identity, err := executeSTSIdentityRequest(ctx, a.httpClientForAWSSTS, identityRequest) if err != nil { - return trace.Wrap(err, "executing STS request") + return nil, trace.Wrap(err, "executing STS request") } // check that the node identity matches an allow rule for this token if err := checkIAMAllowRules(identity, provisionToken.GetName(), provisionToken.GetAllowRules()); err != nil { - return trace.Wrap(err, "checking allow rules") + return identity, trace.Wrap(err, "checking allow rules") } - return nil + return identity, nil } func generateIAMChallenge() (string, error) { @@ -341,11 +354,12 @@ func (a *Server) RegisterUsingIAMMethodWithOpts( ) (certs *proto.Certs, err error) { var provisionToken types.ProvisionToken var joinRequest *types.RegisterUsingTokenRequest + var joinFailureMetadata any defer func() { // Emit a log message and audit event on join failure. if err != nil { a.handleJoinFailure( - err, provisionToken, nil, joinRequest, + err, provisionToken, joinFailureMetadata, joinRequest, ) } }() @@ -377,15 +391,27 @@ func (a *Server) RegisterUsingIAMMethodWithOpts( } // check that the GetCallerIdentity request is valid and matches the token - if err := a.checkIAMRequest(ctx, challenge, req, cfg); err != nil { + verifiedIdentity, err := a.checkIAMRequest(ctx, challenge, req, cfg) + if verifiedIdentity != nil { + joinFailureMetadata = verifiedIdentity + } + if err != nil { return nil, trace.Wrap(err, "checking iam request") } if req.RegisterUsingTokenRequest.Role == types.RoleBot { - certs, err := a.generateCertsBot(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) + certs, err := a.generateCertsBot( + ctx, + provisionToken, + req.RegisterUsingTokenRequest, + verifiedIdentity, + &workloadidentityv1pb.JoinAttrs{ + Iam: verifiedIdentity.JoinAttrs(), + }, + ) return certs, trace.Wrap(err, "generating bot certs") } - certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) + certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, verifiedIdentity) return certs, trace.Wrap(err, "generating certs") } diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go index 05bf9e3c35a54..cdfad3e680350 100644 --- a/lib/auth/join_tpm.go +++ b/lib/auth/join_tpm.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/tpm" @@ -39,12 +40,12 @@ func (a *Server) RegisterUsingTPMMethod( solveChallenge client.RegisterTPMChallengeResponseFunc, ) (_ *proto.Certs, err error) { var provisionToken types.ProvisionToken - var attributeSrc joinAttributeSourcer + var joinFailureMetadata any defer func() { // Emit a log message and audit event on join failure. if err != nil { a.handleJoinFailure( - err, provisionToken, attributeSrc, initReq.JoinRequest, + err, provisionToken, joinFailureMetadata, initReq.JoinRequest, ) } }() @@ -99,10 +100,12 @@ func (a *Server) RegisterUsingTPMMethod( return solution.Solution, nil }, }) + if validatedEK != nil { + joinFailureMetadata = validatedEK + } if err != nil { return nil, trace.Wrap(err, "validating TPM EK") } - attributeSrc = validatedEK if err := checkTPMAllowRules(validatedEK, ptv2.Spec.TPM.Allow); err != nil { return nil, trace.Wrap(err) @@ -110,7 +113,13 @@ func (a *Server) RegisterUsingTPMMethod( if initReq.JoinRequest.Role == types.RoleBot { certs, err := a.generateCertsBot( - ctx, ptv2, initReq.JoinRequest, validatedEK, + ctx, + ptv2, + initReq.JoinRequest, + validatedEK, + &workloadidentityv1pb.JoinAttrs{ + Tpm: validatedEK.JoinAttrs(), + }, ) return certs, trace.Wrap(err, "generating certs for bot") } diff --git a/lib/tpm/validate.go b/lib/tpm/validate.go index 268857d35e4ff..126133d31e644 100644 --- a/lib/tpm/validate.go +++ b/lib/tpm/validate.go @@ -27,6 +27,8 @@ import ( "github.com/google/go-attestation/attest" "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" ) // ValidateParams are the parameters required to validate a TPM. @@ -63,14 +65,17 @@ type ValidatedTPM struct { EKCertVerified bool `json:"ek_cert_verified"` } -// JoinAuditAttributes returns a series of attributes that can be inserted into -// audit events related to a specific join. -func (c *ValidatedTPM) JoinAuditAttributes() (map[string]interface{}, error) { - return map[string]interface{}{ - "ek_pub_hash": c.EKPubHash, - "ek_cert_serial": c.EKCertSerial, - "ek_cert_verified": c.EKCertVerified, - }, nil +// JoinAttrs returns the protobuf representation of the attested identity. +// This is used for auditing and for evaluation of WorkloadIdentity rules and +// templating. +func (c *ValidatedTPM) JoinAttrs() *workloadidentityv1pb.JoinAttrsTPM { + attrs := &workloadidentityv1pb.JoinAttrsTPM{ + EkPubHash: c.EKPubHash, + EkCertSerial: c.EKCertSerial, + EkCertVerified: c.EKCertVerified, + } + + return attrs } // Validate takes the parameters from a remote TPM and performs the necessary