diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 43fbcbfdc4e62..f233963fea0fe 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5230,6 +5230,23 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { } workloadidentityv1pb.RegisterWorkloadIdentityResourceServiceServer(server, workloadIdentityResourceService) + clusterName, err := cfg.AuthServer.GetClusterName() + if err != nil { + return nil, trace.Wrap(err, "getting cluster name") + } + workloadIdentityIssuanceService, err := workloadidentityv1.NewIssuanceService(&workloadidentityv1.IssuanceServiceConfig{ + Authorizer: cfg.Authorizer, + Cache: cfg.AuthServer.Cache, + Emitter: cfg.Emitter, + Clock: cfg.AuthServer.GetClock(), + KeyStore: cfg.AuthServer.keyStore, + ClusterName: clusterName.GetClusterName(), + }) + if err != nil { + return nil, trace.Wrap(err, "creating workload identity issuance service") + } + workloadidentityv1pb.RegisterWorkloadIdentityIssuanceServiceServer(server, workloadIdentityIssuanceService) + dbObjectImportRuleService, err := dbobjectimportrulev1.NewDatabaseObjectImportRuleService(dbobjectimportrulev1.DatabaseObjectImportRuleServiceConfig{ Authorizer: cfg.Authorizer, Backend: cfg.AuthServer.Services, diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index d4242442ab85f..b04c24e80f190 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -1324,11 +1324,38 @@ func CreateUser(ctx context.Context, clt clt, username string, roles ...types.Ro return created, trace.Wrap(err) } +// createUserAndRoleOptions is a set of options for CreateUserAndRole +type createUserAndRoleOptions struct { + mutateUser []func(user types.User) + mutateRole []func(role types.Role) +} + +// CreateUserAndRoleOption is a functional option for CreateUserAndRole +type CreateUserAndRoleOption func(*createUserAndRoleOptions) + +// WithUserMutator sets a function that will be called to mutate the user before it is created +func WithUserMutator(mutate ...func(user types.User)) CreateUserAndRoleOption { + return func(o *createUserAndRoleOptions) { + o.mutateUser = append(o.mutateUser, mutate...) + } +} + +// WithRoleMutator sets a function that will be called to mutate the role before it is created +func WithRoleMutator(mutate ...func(role types.Role)) CreateUserAndRoleOption { + return func(o *createUserAndRoleOptions) { + o.mutateRole = append(o.mutateRole, mutate...) + } +} + // CreateUserAndRole creates user and role and assigns role to a user, used in tests // If allowRules is nil, the role has admin privileges. // If allowRules is not-nil, then the rules associated with the role will be // replaced with those specified. -func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRules []types.Rule) (types.User, types.Role, error) { +func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRules []types.Rule, opts ...CreateUserAndRoleOption) (types.User, types.Role, error) { + o := createUserAndRoleOptions{} + for _, opt := range opts { + opt(&o) + } ctx := context.TODO() user, err := types.NewUser(username) if err != nil { @@ -1340,6 +1367,9 @@ func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRu if allowRules != nil { role.SetRules(types.Allow, allowRules) } + for _, mutate := range o.mutateRole { + mutate(role) + } upsertedRole, err := clt.UpsertRole(ctx, role) if err != nil { @@ -1347,6 +1377,9 @@ func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRu } user.AddRole(upsertedRole.GetName()) + for _, mutate := range o.mutateUser { + mutate(user) + } created, err := clt.UpsertUser(ctx, user) if err != nil { return nil, nil, trace.Wrap(err) diff --git a/lib/auth/machineid/workloadidentityv1/experiment/experiment.go b/lib/auth/machineid/workloadidentityv1/experiment/experiment.go new file mode 100644 index 0000000000000..fafe51ea83d1f --- /dev/null +++ b/lib/auth/machineid/workloadidentityv1/experiment/experiment.go @@ -0,0 +1,41 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package experiment + +import ( + "os" + "sync" +) + +var mu sync.Mutex + +var experimentEnabled = os.Getenv("TELEPORT_WORKLOAD_IDENTITY_UX_EXPERIMENT") == "1" + +// Enabled returns true if the workload identity UX experiment is +// enabled. +func Enabled() bool { + mu.Lock() + defer mu.Unlock() + return experimentEnabled +} + +// SetEnabled sets the experiment enabled flag. +func SetEnabled(enabled bool) { + mu.Lock() + defer mu.Unlock() + experimentEnabled = enabled +} diff --git a/lib/auth/machineid/workloadidentityv1/issuer_service.go b/lib/auth/machineid/workloadidentityv1/issuer_service.go new file mode 100644 index 0000000000000..70a7fa1197974 --- /dev/null +++ b/lib/auth/machineid/workloadidentityv1/issuer_service.go @@ -0,0 +1,608 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package workloadidentityv1 + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/x509" + "log/slog" + "math/big" + "net/url" + "regexp" + "slices" + "strings" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "go.opentelemetry.io/otel" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/gravitational/teleport" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment" + "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/jwt" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/oidc" +) + +var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1") + +// KeyStorer is an interface that provides methods to retrieve keys and +// certificates from the backend. +type KeyStorer interface { + GetTLSCertAndSigner(ctx context.Context, ca types.CertAuthority) ([]byte, crypto.Signer, error) + GetJWTSigner(ctx context.Context, ca types.CertAuthority) (crypto.Signer, error) +} + +type issuerCache interface { + workloadIdentityReader + GetProxies() ([]types.Server, error) + GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) +} + +// IssuanceServiceConfig holds configuration options for the IssuanceService. +type IssuanceServiceConfig struct { + Authorizer authz.Authorizer + Cache issuerCache + Clock clockwork.Clock + Emitter apievents.Emitter + Logger *slog.Logger + KeyStore KeyStorer + + ClusterName string +} + +// IssuanceService is the gRPC service for managing workload identity resources. +// It implements the workloadidentityv1pb.WorkloadIdentityIssuanceServiceServer. +type IssuanceService struct { + workloadidentityv1pb.UnimplementedWorkloadIdentityIssuanceServiceServer + + authorizer authz.Authorizer + cache issuerCache + clock clockwork.Clock + emitter apievents.Emitter + logger *slog.Logger + keyStore KeyStorer + + clusterName string +} + +// NewIssuanceService returns a new instance of the IssuanceService. +func NewIssuanceService(cfg *IssuanceServiceConfig) (*IssuanceService, error) { + switch { + case cfg.Cache == nil: + return nil, trace.BadParameter("cache service is required") + case cfg.Authorizer == nil: + return nil, trace.BadParameter("authorizer is required") + case cfg.Emitter == nil: + return nil, trace.BadParameter("emitter is required") + case cfg.KeyStore == nil: + return nil, trace.BadParameter("key store is required") + case cfg.ClusterName == "": + return nil, trace.BadParameter("cluster name is required") + } + + if cfg.Logger == nil { + cfg.Logger = slog.With(teleport.ComponentKey, "workload_identity_issuance.service") + } + if cfg.Clock == nil { + cfg.Clock = clockwork.NewRealClock() + } + return &IssuanceService{ + authorizer: cfg.Authorizer, + cache: cfg.Cache, + clock: cfg.Clock, + emitter: cfg.Emitter, + logger: cfg.Logger, + keyStore: cfg.KeyStore, + clusterName: cfg.ClusterName, + }, nil +} + +// getFieldStringValue returns a string value from the given attribute set. +// The attribute is specified as a dot-separated path to the field in the +// attribute set. +// +// The specified attribute must be a string field. If the attribute is not +// found, an error is returned. +// +// TODO(noah): This function will be replaced by the Teleport predicate language +// in a coming PR. +func getFieldStringValue(attrs *workloadidentityv1pb.Attrs, attr string) (string, error) { + attrParts := strings.Split(attr, ".") + message := attrs.ProtoReflect() + // TODO(noah): Improve errors by including the fully qualified attribute + // (e.g add up the parts of the attribute path processed thus far) + for i, part := range attrParts { + fieldDesc := message.Descriptor().Fields().ByTextName(part) + if fieldDesc == nil { + return "", trace.NotFound("attribute %q not found", part) + } + // We expect the final key to point to a string field - otherwise - we + // return an error. + if i == len(attrParts)-1 { + if !slices.Contains([]protoreflect.Kind{ + protoreflect.StringKind, + protoreflect.BoolKind, + protoreflect.Int32Kind, + protoreflect.Int64Kind, + protoreflect.Uint64Kind, + protoreflect.Uint32Kind, + }, fieldDesc.Kind()) { + return "", trace.BadParameter("attribute %q of type %q cannot be converted to string", part, fieldDesc.Kind()) + } + return message.Get(fieldDesc).String(), nil + } + // If we're not processing the final key part, we expect this to point + // to a message that we can further explore. + if fieldDesc.Kind() != protoreflect.MessageKind { + return "", trace.BadParameter("attribute %q is not a message", part) + } + message = message.Get(fieldDesc).Message() + } + return "", nil +} + +// templateString takes a given input string and replaces any values within +// {{ }} with values from the attribute set. +// +// If the specified value is not found in the attribute set, an error is +// returned. +// +// TODO(noah): In a coming PR, this will be replaced by evaluating the values +// within the handlebars as expressions. +func templateString(in string, attrs *workloadidentityv1pb.Attrs) (string, error) { + re := regexp.MustCompile(`\{\{([^{}]+?)\}\}`) + matches := re.FindAllStringSubmatch(in, -1) + + for _, match := range matches { + attrKey := strings.TrimSpace(match[1]) + value, err := getFieldStringValue(attrs, attrKey) + if err != nil { + return "", trace.Wrap(err, "fetching attribute value for %q", attrKey) + } + // We want to have an implicit rule here that if an attribute is + // included in the template, but is not set, we should refuse to issue + // the credential. + if value == "" { + return "", trace.NotFound("attribute %q unset", attrKey) + } + in = strings.Replace(in, match[0], value, 1) + } + + return in, nil +} + +func evaluateRules( + wi *workloadidentityv1pb.WorkloadIdentity, + attrs *workloadidentityv1pb.Attrs, +) error { + if len(wi.GetSpec().GetRules().GetAllow()) == 0 { + return nil + } +ruleLoop: + for _, rule := range wi.GetSpec().GetRules().GetAllow() { + for _, condition := range rule.GetConditions() { + val, err := getFieldStringValue(attrs, condition.Attribute) + if err != nil { + return trace.Wrap(err) + } + if val != condition.Equals { + continue ruleLoop + } + } + return nil + } + // TODO: Eventually, we'll need to work support for deny rules into here. + return trace.AccessDenied("no matching rule found") +} + +func (s *IssuanceService) deriveAttrs( + authzCtx *authz.Context, + workloadAttrs *workloadidentityv1pb.WorkloadAttrs, +) (*workloadidentityv1pb.Attrs, error) { + attrs := &workloadidentityv1pb.Attrs{ + Workload: workloadAttrs, + User: &workloadidentityv1pb.UserAttrs{ + Name: authzCtx.Identity.GetIdentity().Username, + IsBot: authzCtx.Identity.GetIdentity().BotName != "", + BotName: authzCtx.Identity.GetIdentity().BotName, + Labels: authzCtx.User.GetAllLabels(), + }, + } + + return attrs, nil +} + +var defaultMaxTTL = 24 * time.Hour + +func (s *IssuanceService) IssueWorkloadIdentity( + ctx context.Context, + req *workloadidentityv1pb.IssueWorkloadIdentityRequest, +) (*workloadidentityv1pb.IssueWorkloadIdentityResponse, error) { + if !experiment.Enabled() { + return nil, trace.AccessDenied("workload identity issuance experiment is disabled") + } + + authCtx, err := s.authorizer.Authorize(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + switch { + case req.GetName() == "": + return nil, trace.BadParameter("name: is required") + case req.GetCredential() == nil: + return nil, trace.BadParameter("at least one credential type must be requested") + } + + wi, err := s.cache.GetWorkloadIdentity(ctx, req.GetName()) + if err != nil { + return nil, trace.Wrap(err) + } + // Check the principal has access to the workload identity resource by + // virtue of WorkloadIdentityLabels on a role. + if err := authCtx.Checker.CheckAccess( + types.Resource153ToResourceWithLabels(wi), + services.AccessState{}, + ); err != nil { + return nil, trace.Wrap(err) + } + + attrs, err := s.deriveAttrs(authCtx, req.GetWorkloadAttrs()) + if err != nil { + return nil, trace.Wrap(err, "deriving attributes") + } + // Evaluate any rules explicitly configured by the user + if err := evaluateRules(wi, attrs); err != nil { + return nil, trace.Wrap(err) + } + + // Perform any templating + spiffeIDPath, err := templateString(wi.GetSpec().GetSpiffe().GetId(), attrs) + if err != nil { + return nil, trace.Wrap(err, "templating spec.spiffe.id") + } + spiffeID, err := spiffeid.FromURI(&url.URL{ + Scheme: "spiffe", + Host: s.clusterName, + Path: spiffeIDPath, + }) + if err != nil { + return nil, trace.Wrap(err, "creating SPIFFE ID") + } + + hint, err := templateString(wi.GetSpec().GetSpiffe().GetHint(), attrs) + if err != nil { + return nil, trace.Wrap(err, "templating spec.spiffe.hint") + } + + // TODO(noah): Add more sophisticated control of the TTL. + ttl := time.Hour + if req.RequestedTtl != nil && req.RequestedTtl.AsDuration() != 0 { + ttl = req.RequestedTtl.AsDuration() + if ttl > defaultMaxTTL { + ttl = defaultMaxTTL + } + } + + now := s.clock.Now() + notBefore := now.Add(-1 * time.Minute) + notAfter := now.Add(ttl) + + // Prepare event + evt := &apievents.SPIFFESVIDIssued{ + Metadata: apievents.Metadata{ + Type: events.SPIFFESVIDIssuedEvent, + Code: events.SPIFFESVIDIssuedSuccessCode, + }, + UserMetadata: authz.ClientUserMetadata(ctx), + ConnectionMetadata: authz.ConnectionMetadata(ctx), + SPIFFEID: spiffeID.String(), + Hint: hint, + WorkloadIdentity: wi.GetMetadata().GetName(), + WorkloadIdentityRevision: wi.GetMetadata().GetRevision(), + } + cred := &workloadidentityv1pb.Credential{ + WorkloadIdentityName: wi.GetMetadata().GetName(), + WorkloadIdentityRevision: wi.GetMetadata().GetRevision(), + + SpiffeId: spiffeID.String(), + Hint: hint, + + ExpiresAt: timestamppb.New(notAfter), + Ttl: durationpb.New(ttl), + } + + switch v := req.GetCredential().(type) { + case *workloadidentityv1pb.IssueWorkloadIdentityRequest_X509SvidParams: + evt.SVIDType = "x509" + certDer, certSerial, err := s.issueX509SVID( + ctx, + v.X509SvidParams, + notBefore, + notAfter, + spiffeID, + ) + if err != nil { + return nil, trace.Wrap(err, "issuing X509 SVID") + } + serialStr := serialString(certSerial) + cred.Credential = &workloadidentityv1pb.Credential_X509Svid{ + X509Svid: &workloadidentityv1pb.X509SVIDCredential{ + Cert: certDer, + SerialNumber: serialStr, + }, + } + evt.SerialNumber = serialStr + case *workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams: + evt.SVIDType = "jwt" + signedJwt, jti, err := s.issueJWTSVID( + ctx, + v.JwtSvidParams, + now, + notAfter, + spiffeID, + ) + if err != nil { + return nil, trace.Wrap(err, "issuing JWT SVID") + } + cred.Credential = &workloadidentityv1pb.Credential_JwtSvid{ + JwtSvid: &workloadidentityv1pb.JWTSVIDCredential{ + Jwt: signedJwt, + Jti: jti, + }, + } + evt.JTI = jti + default: + return nil, trace.BadParameter("credential: unknown type %T", req.GetCredential()) + } + + if err := s.emitter.EmitAuditEvent(ctx, evt); err != nil { + s.logger.WarnContext( + ctx, + "failed to emit audit event for SVID issuance", + "error", err, + "event", evt, + ) + } + + return &workloadidentityv1pb.IssueWorkloadIdentityResponse{ + Credential: cred, + }, nil +} + +func generateCertSerial() (*big.Int, error) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + return rand.Int(rand.Reader, serialNumberLimit) +} + +func x509Template( + serialNumber *big.Int, + notBefore time.Time, + notAfter time.Time, + spiffeID spiffeid.ID, +) *x509.Certificate { + return &x509.Certificate{ + SerialNumber: serialNumber, + NotBefore: notBefore, + NotAfter: notAfter, + // SPEC(X509-SVID) 4.3. Key Usage: + // - Leaf SVIDs MUST NOT set keyCertSign or cRLSign. + // - Leaf SVIDs MUST set digitalSignature + // - They MAY set keyEncipherment and/or keyAgreement; + KeyUsage: x509.KeyUsageDigitalSignature | + x509.KeyUsageKeyEncipherment | + x509.KeyUsageKeyAgreement, + // SPEC(X509-SVID) 4.4. Extended Key Usage: + // - Leaf SVIDs SHOULD include this extension, and it MAY be marked as critical. + // - When included, fields id-kp-serverAuth and id-kp-clientAuth MUST be set. + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth, + }, + // SPEC(X509-SVID) 4.1. Basic Constraints: + // - leaf certificates MUST set the cA field to false + BasicConstraintsValid: true, + IsCA: false, + + // SPEC(X509-SVID) 2. SPIFFE ID: + // - The corresponding SPIFFE ID is set as a URI type in the Subject Alternative Name extension + // - An X.509 SVID MUST contain exactly one URI SAN, and by extension, exactly one SPIFFE ID. + // - An X.509 SVID MAY contain any number of other SAN field types, including DNS SANs. + URIs: []*url.URL{spiffeID.URL()}, + } +} + +func (s *IssuanceService) getX509CA( + ctx context.Context, +) (_ *tlsca.CertAuthority, err error) { + ctx, span := tracer.Start(ctx, "IssuanceService/getX509CA") + defer func() { tracing.EndSpan(span, err) }() + + ca, err := s.cache.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.SPIFFECA, + DomainName: s.clusterName, + }, true) + tlsCert, tlsSigner, err := s.keyStore.GetTLSCertAndSigner(ctx, ca) + if err != nil { + return nil, trace.Wrap(err, "getting CA cert and key") + } + tlsCA, err := tlsca.FromCertAndSigner(tlsCert, tlsSigner) + if err != nil { + return nil, trace.Wrap(err) + } + return tlsCA, nil +} + +func (s *IssuanceService) issueX509SVID( + ctx context.Context, + params *workloadidentityv1pb.X509SVIDParams, + notBefore time.Time, + notAfter time.Time, + spiffeID spiffeid.ID, +) (_ []byte, _ *big.Int, err error) { + ctx, span := tracer.Start(ctx, "IssuanceService/issueX509SVID") + defer func() { tracing.EndSpan(span, err) }() + + switch { + case params == nil: + return nil, nil, trace.BadParameter("x509_svid_params: is required") + case len(params.PublicKey) == 0: + return nil, nil, trace.BadParameter("x509_svid_params.public_key: is required") + } + + pubKey, err := x509.ParsePKIXPublicKey(params.PublicKey) + if err != nil { + return nil, nil, trace.Wrap(err, "parsing public key") + } + + certSerial, err := generateCertSerial() + if err != nil { + return nil, nil, trace.Wrap(err, "generating certificate serial") + } + template := x509Template(certSerial, notBefore, notAfter, spiffeID) + + ca, err := s.getX509CA(ctx) + if err != nil { + return nil, nil, trace.Wrap(err, "fetching CA to sign X509 SVID") + } + certBytes, err := x509.CreateCertificate( + rand.Reader, template, ca.Cert, pubKey, ca.Signer, + ) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return certBytes, certSerial, nil +} + +const jtiLength = 16 + +func (s *IssuanceService) getJWTIssuerKey( + ctx context.Context, +) (_ *jwt.Key, err error) { + ctx, span := tracer.Start(ctx, "IssuanceService/getJWTIssuerKey") + defer func() { tracing.EndSpan(span, err) }() + + ca, err := s.cache.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.SPIFFECA, + DomainName: s.clusterName, + }, true) + if err != nil { + return nil, trace.Wrap(err, "getting SPIFFE CA") + } + + jwtSigner, err := s.keyStore.GetJWTSigner(ctx, ca) + if err != nil { + return nil, trace.Wrap(err, "getting JWT signer") + } + + jwtKey, err := services.GetJWTSigner( + jwtSigner, s.clusterName, s.clock, + ) + if err != nil { + return nil, trace.Wrap(err, "creating JWT signer") + } + return jwtKey, nil +} + +func (s *IssuanceService) issueJWTSVID( + ctx context.Context, + params *workloadidentityv1pb.JWTSVIDParams, + now time.Time, + notAfter time.Time, + spiffeID spiffeid.ID, +) (_ string, _ string, err error) { + ctx, span := tracer.Start(ctx, "IssuanceService/issueJWTSVID") + defer func() { tracing.EndSpan(span, err) }() + + switch { + case params == nil: + return "", "", trace.BadParameter("jwt_svid_params: is required") + case len(params.Audiences) == 0: + return "", "", trace.BadParameter("jwt_svid_params.audiences: at least one audience should be specified") + } + + jti, err := utils.CryptoRandomHex(jtiLength) + if err != nil { + return "", "", trace.Wrap(err, "generating JTI") + } + + key, err := s.getJWTIssuerKey(ctx) + if err != nil { + return "", "", trace.Wrap(err, "getting JWT issuer key") + } + + // Determine the public address of the proxy for inclusion in the JWT as + // the issuer for purposes of OIDC compatibility. + issuer, err := oidc.IssuerForCluster(ctx, s.cache, "/workload-identity") + if err != nil { + return "", "", trace.Wrap(err, "determining issuer URI") + } + + signed, err := key.SignJWTSVID(jwt.SignParamsJWTSVID{ + Audiences: params.Audiences, + SPIFFEID: spiffeID, + JTI: jti, + Issuer: issuer, + + SetIssuedAt: now, + SetExpiry: notAfter, + }) + if err != nil { + return "", "", trace.Wrap(err, "signing jwt") + } + + return signed, jti, nil +} + +func (s *IssuanceService) IssueWorkloadIdentities( + ctx context.Context, + req *workloadidentityv1pb.IssueWorkloadIdentitiesRequest, +) (*workloadidentityv1pb.IssueWorkloadIdentitiesResponse, error) { + // TODO(noah): Coming to a PR near you soon! + return nil, trace.NotImplemented("not implemented") +} + +func serialString(serial *big.Int) string { + hex := serial.Text(16) + if len(hex)%2 == 1 { + hex = "0" + hex + } + + out := strings.Builder{} + for i := 0; i < len(hex); i += 2 { + if i != 0 { + out.WriteString(":") + } + out.WriteString(hex[i : i+2]) + } + return out.String() +} diff --git a/lib/auth/machineid/workloadidentityv1/issuer_service_test.go b/lib/auth/machineid/workloadidentityv1/issuer_service_test.go new file mode 100644 index 0000000000000..bf18594416609 --- /dev/null +++ b/lib/auth/machineid/workloadidentityv1/issuer_service_test.go @@ -0,0 +1,223 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package workloadidentityv1 + +import ( + "testing" + + "github.com/stretchr/testify/require" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +func Test_getFieldStringValue(t *testing.T) { + tests := []struct { + name string + in *workloadidentityv1pb.Attrs + attr string + want string + requireErr require.ErrorAssertionFunc + }{ + { + name: "success", + in: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "jeff", + }, + }, + attr: "user.name", + want: "jeff", + requireErr: require.NoError, + }, + { + name: "bool", + in: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "jeff", + }, + Workload: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ + Attested: true, + }, + }, + }, + attr: "workload.unix.attested", + want: "true", + requireErr: require.NoError, + }, + { + name: "int32", + in: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "jeff", + }, + Workload: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ + Pid: 123, + }, + }, + }, + attr: "workload.unix.pid", + want: "123", + requireErr: require.NoError, + }, + { + name: "uint32", + in: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "jeff", + }, + Workload: &workloadidentityv1pb.WorkloadAttrs{ + Unix: &workloadidentityv1pb.WorkloadAttrsUnix{ + Gid: 123, + }, + }, + }, + attr: "workload.unix.gid", + want: "123", + requireErr: require.NoError, + }, + { + name: "non-string final field", + in: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "user", + }, + }, + attr: "user", + requireErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "attribute \"user\" of type \"message\" cannot be converted to string") + }, + }, + { + // We mostly just want this to not panic. + name: "nil root", + in: nil, + attr: "user.name", + want: "", + requireErr: require.NoError, + }, + { + // We mostly just want this to not panic. + name: "nil submessage", + in: &workloadidentityv1pb.Attrs{ + User: nil, + }, + attr: "user.name", + want: "", + requireErr: require.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotErr := getFieldStringValue(tt.in, tt.attr) + tt.requireErr(t, gotErr) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_templateString(t *testing.T) { + tests := []struct { + name string + in string + want string + attrs *workloadidentityv1pb.Attrs + requireErr require.ErrorAssertionFunc + }{ + { + name: "success mixed", + in: "hello{{user.name}}.{{user.name}} {{ workload.kubernetes.pod_name }}//{{ workload.kubernetes.namespace}}", + want: "hellojeff.jeff pod1//default", + attrs: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "jeff", + }, + Workload: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ + PodName: "pod1", + Namespace: "default", + }, + }, + }, + requireErr: require.NoError, + }, + { + name: "success with spaces", + in: "hello {{user.name}}", + want: "hello jeff", + attrs: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "jeff", + }, + }, + requireErr: require.NoError, + }, + { + name: "fail due to unset", + in: "hello {{workload.kubernetes.pod_name}}", + attrs: &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "jeff", + }, + }, + requireErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "attribute \"workload.kubernetes.pod_name\" unset") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotErr := templateString(tt.in, tt.attrs) + tt.requireErr(t, gotErr) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_evaluateRules(t *testing.T) { + attrs := &workloadidentityv1pb.Attrs{ + User: &workloadidentityv1pb.UserAttrs{ + Name: "foo", + }, + } + wi := &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Equals: "foo", + }, + }, + }, + }, + }, + }, + } + err := evaluateRules(wi, attrs) + require.NoError(t, err) +} diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go index 1c0601a34dd54..b911885031995 100644 --- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go +++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go @@ -18,6 +18,8 @@ package workloadidentityv1_test import ( "context" + "crypto/rsa" + "crypto/x509" "errors" "fmt" "net" @@ -26,6 +28,7 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3/jwt" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/gravitational/trace" @@ -33,6 +36,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/durationpb" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" @@ -40,9 +44,13 @@ import ( "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment" + "github.com/gravitational/teleport/lib/auth/native" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" + libjwt "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/services" ) func TestMain(m *testing.M) { @@ -74,6 +82,396 @@ func newTestTLSServer(t testing.TB) (*auth.TestTLSServer, *eventstest.MockRecord return srv, emitter } +func TestIssueWorkloadIdentity(t *testing.T) { + experimentStatus := experiment.Enabled() + defer experiment.SetEnabled(experimentStatus) + experiment.SetEnabled(true) + + srv, eventRecorder := newTestTLSServer(t) + ctx := context.Background() + clock := srv.Auth().GetClock() + + // Upsert a fake proxy to ensure we have a public address to use for the + // issuer. + proxy, err := types.NewServer("proxy", types.KindProxy, types.ServerSpecV2{ + PublicAddrs: []string{"teleport.example.com"}, + }) + require.NoError(t, err) + err = srv.Auth().UpsertProxy(ctx, proxy) + require.NoError(t, err) + wantIssuer := "https://teleport.example.com/workload-identity" + + // Fetch X509 SPIFFE CA for validation of signature later + spiffeX509CA, err := srv.Auth().GetCertAuthority(ctx, types.CertAuthID{ + Type: types.SPIFFECA, + DomainName: srv.ClusterName(), + }, false) + require.NoError(t, err) + spiffeX509CAPool, err := services.CertPool(spiffeX509CA) + require.NoError(t, err) + // Fetch JWT CA to validate JWTs + jwtCA, err := srv.Auth().GetCertAuthority(ctx, types.CertAuthID{ + Type: types.SPIFFECA, + DomainName: "localhost", + }, true) + require.NoError(t, err) + jwtSigner, err := srv.Auth().GetKeyStore().GetJWTSigner(ctx, jwtCA) + require.NoError(t, err) + kid := libjwt.KeyID(jwtSigner.Public().(*rsa.PublicKey)) + + wildcardAccess, _, err := auth.CreateUserAndRole( + srv.Auth(), + "dog", + []string{}, + []types.Rule{}, + auth.WithRoleMutator(func(role types.Role) { + role.SetWorkloadIdentityLabels(types.Allow, types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }) + }), + ) + require.NoError(t, err) + wilcardAccessClient, err := srv.NewClient(auth.TestUser(wildcardAccess.GetName())) + require.NoError(t, err) + + specificAccess, _, err := auth.CreateUserAndRole( + srv.Auth(), + "cat", + []string{}, + []types.Rule{}, + auth.WithRoleMutator(func(role types.Role) { + role.SetWorkloadIdentityLabels(types.Allow, types.Labels{ + "foo": []string{"bar"}, + }) + }), + ) + require.NoError(t, err) + specificAccessClient, err := srv.NewClient(auth.TestUser(specificAccess.GetName())) + require.NoError(t, err) + + // Generate a keypair to generate x509 SVIDs for. + workloadKey, err := native.GenerateRSAPrivateKey() + require.NoError(t, err) + workloadKeyPubBytes, err := x509.MarshalPKIXPublicKey(workloadKey.Public()) + require.NoError(t, err) + + // Create some WorkloadIdentity resources + full, err := srv.Auth().CreateWorkloadIdentity(ctx, &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "full", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Equals: "dog", + }, + { + Attribute: "workload.kubernetes.namespace", + Equals: "default", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example/{{user.name}}/{{ workload.kubernetes.namespace }}/{{ workload.kubernetes.service_account }}", + Hint: "Wow - what a lovely hint, {{user.name}}!", + }, + }, + }) + require.NoError(t, err) + + workloadAttrs := func(f func(attrs *workloadidentityv1pb.WorkloadAttrs)) *workloadidentityv1pb.WorkloadAttrs { + attrs := &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ + Attested: true, + Namespace: "default", + PodName: "test", + ServiceAccount: "bar", + }, + } + if f != nil { + f(attrs) + } + return attrs + } + tests := []struct { + name string + client *authclient.Client + req *workloadidentityv1pb.IssueWorkloadIdentityRequest + requireErr require.ErrorAssertionFunc + assert func(*testing.T, *workloadidentityv1pb.IssueWorkloadIdentityResponse) + }{ + { + name: "jwt svid", + client: wilcardAccessClient, + req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{ + Name: full.GetMetadata().GetName(), + Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{ + JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{ + Audiences: []string{"example.com", "test.example.com"}, + }, + }, + WorkloadAttrs: workloadAttrs(nil), + }, + requireErr: require.NoError, + assert: func(t *testing.T, res *workloadidentityv1pb.IssueWorkloadIdentityResponse) { + cred := res.Credential + require.NotNil(t, res.Credential) + + wantTTL := time.Hour + wantSPIFFEID := "spiffe://localhost/example/dog/default/bar" + require.Empty(t, cmp.Diff( + cred, + &workloadidentityv1pb.Credential{ + Ttl: durationpb.New(wantTTL), + SpiffeId: wantSPIFFEID, + Hint: "Wow - what a lovely hint, dog!", + WorkloadIdentityName: full.GetMetadata().GetName(), + WorkloadIdentityRevision: full.GetMetadata().GetRevision(), + }, + protocmp.Transform(), + protocmp.IgnoreFields( + &workloadidentityv1pb.Credential{}, + "expires_at", + ), + protocmp.IgnoreOneofs( + &workloadidentityv1pb.Credential{}, + "credential", + ), + )) + // Check expiry makes sense + require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second) + + // Check the JWT + parsed, err := jwt.ParseSigned(cred.GetJwtSvid().GetJwt()) + require.NoError(t, err) + + claims := jwt.Claims{} + err = parsed.Claims(jwtSigner.Public(), &claims) + require.NoError(t, err) + // Check headers + require.Len(t, parsed.Headers, 1) + require.Equal(t, kid, parsed.Headers[0].KeyID) + // Check claims + require.Equal(t, wantSPIFFEID, claims.Subject) + require.NotEmpty(t, claims.ID) + require.Equal(t, jwt.Audience{"example.com", "test.example.com"}, claims.Audience) + require.Equal(t, wantIssuer, claims.Issuer) + require.WithinDuration(t, clock.Now().Add(wantTTL), claims.Expiry.Time(), 5*time.Second) + require.WithinDuration(t, clock.Now(), claims.IssuedAt.Time(), 5*time.Second) + + // Check audit log event + evt, ok := eventRecorder.LastEvent().(*events.SPIFFESVIDIssued) + require.True(t, ok) + require.NotEmpty(t, evt.ConnectionMetadata.RemoteAddr) + require.Equal(t, claims.ID, evt.JTI) + require.Equal(t, claims.ID, cred.GetJwtSvid().GetJti()) + require.Empty(t, cmp.Diff( + evt, + &events.SPIFFESVIDIssued{ + Metadata: events.Metadata{ + Type: libevents.SPIFFESVIDIssuedEvent, + Code: libevents.SPIFFESVIDIssuedSuccessCode, + }, + UserMetadata: events.UserMetadata{ + User: wildcardAccess.GetName(), + UserKind: events.UserKind_USER_KIND_HUMAN, + }, + SPIFFEID: "spiffe://localhost/example/dog/default/bar", + SVIDType: "jwt", + Hint: "Wow - what a lovely hint, dog!", + WorkloadIdentity: full.GetMetadata().GetName(), + WorkloadIdentityRevision: full.GetMetadata().GetRevision(), + }, + cmpopts.IgnoreFields( + events.SPIFFESVIDIssued{}, + "ConnectionMetadata", + "JTI", + ), + )) + }, + }, + { + name: "x509 svid", + client: wilcardAccessClient, + req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{ + Name: full.GetMetadata().GetName(), + Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_X509SvidParams{ + X509SvidParams: &workloadidentityv1pb.X509SVIDParams{ + PublicKey: workloadKeyPubBytes, + }, + }, + WorkloadAttrs: workloadAttrs(nil), + }, + requireErr: require.NoError, + assert: func(t *testing.T, res *workloadidentityv1pb.IssueWorkloadIdentityResponse) { + cred := res.Credential + require.NotNil(t, res.Credential) + + wantSPIFFEID := "spiffe://localhost/example/dog/default/bar" + wantTTL := time.Hour + require.Empty(t, cmp.Diff( + cred, + &workloadidentityv1pb.Credential{ + Ttl: durationpb.New(wantTTL), + SpiffeId: wantSPIFFEID, + Hint: "Wow - what a lovely hint, dog!", + WorkloadIdentityName: full.GetMetadata().GetName(), + WorkloadIdentityRevision: full.GetMetadata().GetRevision(), + }, + protocmp.Transform(), + protocmp.IgnoreFields( + &workloadidentityv1pb.Credential{}, + "expires_at", + ), + protocmp.IgnoreOneofs( + &workloadidentityv1pb.Credential{}, + "credential", + ), + )) + // Check expiry makes sense + require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second) + + // Check the X509 + cert, err := x509.ParseCertificate(cred.GetX509Svid().GetCert()) + require.NoError(t, err) + // Check included public key matches + require.Equal(t, workloadKey.Public(), cert.PublicKey) + // Check cert expiry + require.WithinDuration(t, clock.Now().Add(wantTTL), cert.NotAfter, time.Second) + // Check cert nbf + require.WithinDuration(t, clock.Now().Add(-1*time.Minute), cert.NotBefore, time.Second) + // Check cert TTL + require.Equal(t, cert.NotAfter.Sub(cert.NotBefore), wantTTL+time.Minute) + + // Check against SPIFFE SPEC + // References are to https://github.com/spiffe/spiffe/blob/main/standards/X509-SVID.md + // 2: An X.509 SVID MUST contain exactly one URI SAN, and by extension, exactly one SPIFFE ID + require.Len(t, cert.URIs, 1) + require.Equal(t, wantSPIFFEID, cert.URIs[0].String()) + // 4.1: leaf certificates MUST set the cA field to false. + require.False(t, cert.IsCA) + require.Greater(t, cert.KeyUsage&x509.KeyUsageDigitalSignature, 0) + // 4.3: They MAY set keyEncipherment and/or keyAgreement + require.Greater(t, cert.KeyUsage&x509.KeyUsageKeyEncipherment, 0) + require.Greater(t, cert.KeyUsage&x509.KeyUsageKeyAgreement, 0) + // 4.3: Leaf SVIDs MUST NOT set keyCertSign or cRLSign + require.EqualValues(t, 0, cert.KeyUsage&x509.KeyUsageCertSign) + require.EqualValues(t, 0, cert.KeyUsage&x509.KeyUsageCRLSign) + // 4.4: When included, fields id-kp-serverAuth and id-kp-clientAuth MUST be set. + require.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth) + require.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth) + + // Check cert signature is valid + _, err = cert.Verify(x509.VerifyOptions{ + Roots: spiffeX509CAPool, + CurrentTime: srv.Auth().GetClock().Now(), + }) + require.NoError(t, err) + + // Check audit log event + evt, ok := eventRecorder.LastEvent().(*events.SPIFFESVIDIssued) + require.True(t, ok) + require.NotEmpty(t, evt.ConnectionMetadata.RemoteAddr) + require.Equal(t, cred.GetX509Svid().GetSerialNumber(), evt.SerialNumber) + require.Empty(t, cmp.Diff( + evt, + &events.SPIFFESVIDIssued{ + Metadata: events.Metadata{ + Type: libevents.SPIFFESVIDIssuedEvent, + Code: libevents.SPIFFESVIDIssuedSuccessCode, + }, + UserMetadata: events.UserMetadata{ + User: wildcardAccess.GetName(), + UserKind: events.UserKind_USER_KIND_HUMAN, + }, + SPIFFEID: "spiffe://localhost/example/dog/default/bar", + SVIDType: "x509", + Hint: "Wow - what a lovely hint, dog!", + WorkloadIdentity: full.GetMetadata().GetName(), + WorkloadIdentityRevision: full.GetMetadata().GetRevision(), + }, + cmpopts.IgnoreFields( + events.SPIFFESVIDIssued{}, + "ConnectionMetadata", + "SerialNumber", + ), + )) + }, + }, + { + name: "unauthorized by rules", + client: wilcardAccessClient, + req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{ + Name: full.GetMetadata().GetName(), + Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{ + JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{ + Audiences: []string{"example.com", "test.example.com"}, + }, + }, + WorkloadAttrs: workloadAttrs(func(attrs *workloadidentityv1pb.WorkloadAttrs) { + attrs.Kubernetes.Namespace = "not-default" + }), + }, + requireErr: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "unauthorized by labels", + client: specificAccessClient, + req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{ + Name: full.GetMetadata().GetName(), + Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{ + JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{ + Audiences: []string{"example.com", "test.example.com"}, + }, + }, + WorkloadAttrs: workloadAttrs(nil), + }, + requireErr: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsAccessDenied(err)) + }, + }, + { + name: "does not exist", + client: specificAccessClient, + req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{ + Name: "does-not-exist", + Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{ + JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{ + Audiences: []string{"example.com", "test.example.com"}, + }, + }, + WorkloadAttrs: workloadAttrs(nil), + }, + requireErr: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err)) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eventRecorder.Reset() + c := workloadidentityv1pb.NewWorkloadIdentityIssuanceServiceClient( + tt.client.GetConnection(), + ) + res, err := c.IssueWorkloadIdentity(ctx, tt.req) + tt.requireErr(t, err) + if tt.assert != nil { + tt.assert(t, res) + } + }) + } +} + func TestResourceService_CreateWorkloadIdentity(t *testing.T) { t.Parallel() srv, eventRecorder := newTestTLSServer(t) diff --git a/lib/jwt/jwt.go b/lib/jwt/jwt.go index bfdb78dba21c0..f86012a8e5184 100644 --- a/lib/jwt/jwt.go +++ b/lib/jwt/jwt.go @@ -266,6 +266,12 @@ type SignParamsJWTSVID struct { // Issuer is the value that should be included in the `iss` claim of the // created token. Issuer string + + // SetExpiry overrides the expiry time of the token. This causes the value + // of TTL to be ignored. + SetExpiry time.Time + // SetIssuedAt overrides the issued at time of the token. + SetIssuedAt time.Time } // SignJWTSVID signs a JWT SVID token. @@ -297,6 +303,12 @@ func (k *Key) SignJWTSVID(p SignParamsJWTSVID) (string, error) { // understand OIDC. Issuer: p.Issuer, } + if !p.SetIssuedAt.IsZero() { + claims.IssuedAt = jwt.NewNumericDate(p.SetIssuedAt) + } + if !p.SetExpiry.IsZero() { + claims.Expiry = jwt.NewNumericDate(p.SetExpiry) + } // > 2.2. Key ID: // >The kid header is optional.