diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go
index 43fbcbfdc4e62..f233963fea0fe 100644
--- a/lib/auth/grpcserver.go
+++ b/lib/auth/grpcserver.go
@@ -5230,6 +5230,23 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
}
workloadidentityv1pb.RegisterWorkloadIdentityResourceServiceServer(server, workloadIdentityResourceService)
+ clusterName, err := cfg.AuthServer.GetClusterName()
+ if err != nil {
+ return nil, trace.Wrap(err, "getting cluster name")
+ }
+ workloadIdentityIssuanceService, err := workloadidentityv1.NewIssuanceService(&workloadidentityv1.IssuanceServiceConfig{
+ Authorizer: cfg.Authorizer,
+ Cache: cfg.AuthServer.Cache,
+ Emitter: cfg.Emitter,
+ Clock: cfg.AuthServer.GetClock(),
+ KeyStore: cfg.AuthServer.keyStore,
+ ClusterName: clusterName.GetClusterName(),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err, "creating workload identity issuance service")
+ }
+ workloadidentityv1pb.RegisterWorkloadIdentityIssuanceServiceServer(server, workloadIdentityIssuanceService)
+
dbObjectImportRuleService, err := dbobjectimportrulev1.NewDatabaseObjectImportRuleService(dbobjectimportrulev1.DatabaseObjectImportRuleServiceConfig{
Authorizer: cfg.Authorizer,
Backend: cfg.AuthServer.Services,
diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go
index d4242442ab85f..b04c24e80f190 100644
--- a/lib/auth/helpers.go
+++ b/lib/auth/helpers.go
@@ -1324,11 +1324,38 @@ func CreateUser(ctx context.Context, clt clt, username string, roles ...types.Ro
return created, trace.Wrap(err)
}
+// createUserAndRoleOptions is a set of options for CreateUserAndRole
+type createUserAndRoleOptions struct {
+ mutateUser []func(user types.User)
+ mutateRole []func(role types.Role)
+}
+
+// CreateUserAndRoleOption is a functional option for CreateUserAndRole
+type CreateUserAndRoleOption func(*createUserAndRoleOptions)
+
+// WithUserMutator sets a function that will be called to mutate the user before it is created
+func WithUserMutator(mutate ...func(user types.User)) CreateUserAndRoleOption {
+ return func(o *createUserAndRoleOptions) {
+ o.mutateUser = append(o.mutateUser, mutate...)
+ }
+}
+
+// WithRoleMutator sets a function that will be called to mutate the role before it is created
+func WithRoleMutator(mutate ...func(role types.Role)) CreateUserAndRoleOption {
+ return func(o *createUserAndRoleOptions) {
+ o.mutateRole = append(o.mutateRole, mutate...)
+ }
+}
+
// CreateUserAndRole creates user and role and assigns role to a user, used in tests
// If allowRules is nil, the role has admin privileges.
// If allowRules is not-nil, then the rules associated with the role will be
// replaced with those specified.
-func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRules []types.Rule) (types.User, types.Role, error) {
+func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRules []types.Rule, opts ...CreateUserAndRoleOption) (types.User, types.Role, error) {
+ o := createUserAndRoleOptions{}
+ for _, opt := range opts {
+ opt(&o)
+ }
ctx := context.TODO()
user, err := types.NewUser(username)
if err != nil {
@@ -1340,6 +1367,9 @@ func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRu
if allowRules != nil {
role.SetRules(types.Allow, allowRules)
}
+ for _, mutate := range o.mutateRole {
+ mutate(role)
+ }
upsertedRole, err := clt.UpsertRole(ctx, role)
if err != nil {
@@ -1347,6 +1377,9 @@ func CreateUserAndRole(clt clt, username string, allowedLogins []string, allowRu
}
user.AddRole(upsertedRole.GetName())
+ for _, mutate := range o.mutateUser {
+ mutate(user)
+ }
created, err := clt.UpsertUser(ctx, user)
if err != nil {
return nil, nil, trace.Wrap(err)
diff --git a/lib/auth/machineid/workloadidentityv1/experiment/experiment.go b/lib/auth/machineid/workloadidentityv1/experiment/experiment.go
new file mode 100644
index 0000000000000..fafe51ea83d1f
--- /dev/null
+++ b/lib/auth/machineid/workloadidentityv1/experiment/experiment.go
@@ -0,0 +1,41 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package experiment
+
+import (
+ "os"
+ "sync"
+)
+
+var mu sync.Mutex
+
+var experimentEnabled = os.Getenv("TELEPORT_WORKLOAD_IDENTITY_UX_EXPERIMENT") == "1"
+
+// Enabled returns true if the workload identity UX experiment is
+// enabled.
+func Enabled() bool {
+ mu.Lock()
+ defer mu.Unlock()
+ return experimentEnabled
+}
+
+// SetEnabled sets the experiment enabled flag.
+func SetEnabled(enabled bool) {
+ mu.Lock()
+ defer mu.Unlock()
+ experimentEnabled = enabled
+}
diff --git a/lib/auth/machineid/workloadidentityv1/issuer_service.go b/lib/auth/machineid/workloadidentityv1/issuer_service.go
new file mode 100644
index 0000000000000..70a7fa1197974
--- /dev/null
+++ b/lib/auth/machineid/workloadidentityv1/issuer_service.go
@@ -0,0 +1,608 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package workloadidentityv1
+
+import (
+ "context"
+ "crypto"
+ "crypto/rand"
+ "crypto/x509"
+ "log/slog"
+ "math/big"
+ "net/url"
+ "regexp"
+ "slices"
+ "strings"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/spiffe/go-spiffe/v2/spiffeid"
+ "go.opentelemetry.io/otel"
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/types/known/durationpb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ "github.com/gravitational/teleport"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/observability/tracing"
+ "github.com/gravitational/teleport/api/types"
+ apievents "github.com/gravitational/teleport/api/types/events"
+ "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment"
+ "github.com/gravitational/teleport/lib/authz"
+ "github.com/gravitational/teleport/lib/events"
+ "github.com/gravitational/teleport/lib/jwt"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+ "github.com/gravitational/teleport/lib/utils/oidc"
+)
+
+var tracer = otel.Tracer("github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1")
+
+// KeyStorer is an interface that provides methods to retrieve keys and
+// certificates from the backend.
+type KeyStorer interface {
+ GetTLSCertAndSigner(ctx context.Context, ca types.CertAuthority) ([]byte, crypto.Signer, error)
+ GetJWTSigner(ctx context.Context, ca types.CertAuthority) (crypto.Signer, error)
+}
+
+type issuerCache interface {
+ workloadIdentityReader
+ GetProxies() ([]types.Server, error)
+ GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error)
+}
+
+// IssuanceServiceConfig holds configuration options for the IssuanceService.
+type IssuanceServiceConfig struct {
+ Authorizer authz.Authorizer
+ Cache issuerCache
+ Clock clockwork.Clock
+ Emitter apievents.Emitter
+ Logger *slog.Logger
+ KeyStore KeyStorer
+
+ ClusterName string
+}
+
+// IssuanceService is the gRPC service for managing workload identity resources.
+// It implements the workloadidentityv1pb.WorkloadIdentityIssuanceServiceServer.
+type IssuanceService struct {
+ workloadidentityv1pb.UnimplementedWorkloadIdentityIssuanceServiceServer
+
+ authorizer authz.Authorizer
+ cache issuerCache
+ clock clockwork.Clock
+ emitter apievents.Emitter
+ logger *slog.Logger
+ keyStore KeyStorer
+
+ clusterName string
+}
+
+// NewIssuanceService returns a new instance of the IssuanceService.
+func NewIssuanceService(cfg *IssuanceServiceConfig) (*IssuanceService, error) {
+ switch {
+ case cfg.Cache == nil:
+ return nil, trace.BadParameter("cache service is required")
+ case cfg.Authorizer == nil:
+ return nil, trace.BadParameter("authorizer is required")
+ case cfg.Emitter == nil:
+ return nil, trace.BadParameter("emitter is required")
+ case cfg.KeyStore == nil:
+ return nil, trace.BadParameter("key store is required")
+ case cfg.ClusterName == "":
+ return nil, trace.BadParameter("cluster name is required")
+ }
+
+ if cfg.Logger == nil {
+ cfg.Logger = slog.With(teleport.ComponentKey, "workload_identity_issuance.service")
+ }
+ if cfg.Clock == nil {
+ cfg.Clock = clockwork.NewRealClock()
+ }
+ return &IssuanceService{
+ authorizer: cfg.Authorizer,
+ cache: cfg.Cache,
+ clock: cfg.Clock,
+ emitter: cfg.Emitter,
+ logger: cfg.Logger,
+ keyStore: cfg.KeyStore,
+ clusterName: cfg.ClusterName,
+ }, nil
+}
+
+// getFieldStringValue returns a string value from the given attribute set.
+// The attribute is specified as a dot-separated path to the field in the
+// attribute set.
+//
+// The specified attribute must be a string field. If the attribute is not
+// found, an error is returned.
+//
+// TODO(noah): This function will be replaced by the Teleport predicate language
+// in a coming PR.
+func getFieldStringValue(attrs *workloadidentityv1pb.Attrs, attr string) (string, error) {
+ attrParts := strings.Split(attr, ".")
+ message := attrs.ProtoReflect()
+ // TODO(noah): Improve errors by including the fully qualified attribute
+ // (e.g add up the parts of the attribute path processed thus far)
+ for i, part := range attrParts {
+ fieldDesc := message.Descriptor().Fields().ByTextName(part)
+ if fieldDesc == nil {
+ return "", trace.NotFound("attribute %q not found", part)
+ }
+ // We expect the final key to point to a string field - otherwise - we
+ // return an error.
+ if i == len(attrParts)-1 {
+ if !slices.Contains([]protoreflect.Kind{
+ protoreflect.StringKind,
+ protoreflect.BoolKind,
+ protoreflect.Int32Kind,
+ protoreflect.Int64Kind,
+ protoreflect.Uint64Kind,
+ protoreflect.Uint32Kind,
+ }, fieldDesc.Kind()) {
+ return "", trace.BadParameter("attribute %q of type %q cannot be converted to string", part, fieldDesc.Kind())
+ }
+ return message.Get(fieldDesc).String(), nil
+ }
+ // If we're not processing the final key part, we expect this to point
+ // to a message that we can further explore.
+ if fieldDesc.Kind() != protoreflect.MessageKind {
+ return "", trace.BadParameter("attribute %q is not a message", part)
+ }
+ message = message.Get(fieldDesc).Message()
+ }
+ return "", nil
+}
+
+// templateString takes a given input string and replaces any values within
+// {{ }} with values from the attribute set.
+//
+// If the specified value is not found in the attribute set, an error is
+// returned.
+//
+// TODO(noah): In a coming PR, this will be replaced by evaluating the values
+// within the handlebars as expressions.
+func templateString(in string, attrs *workloadidentityv1pb.Attrs) (string, error) {
+ re := regexp.MustCompile(`\{\{([^{}]+?)\}\}`)
+ matches := re.FindAllStringSubmatch(in, -1)
+
+ for _, match := range matches {
+ attrKey := strings.TrimSpace(match[1])
+ value, err := getFieldStringValue(attrs, attrKey)
+ if err != nil {
+ return "", trace.Wrap(err, "fetching attribute value for %q", attrKey)
+ }
+ // We want to have an implicit rule here that if an attribute is
+ // included in the template, but is not set, we should refuse to issue
+ // the credential.
+ if value == "" {
+ return "", trace.NotFound("attribute %q unset", attrKey)
+ }
+ in = strings.Replace(in, match[0], value, 1)
+ }
+
+ return in, nil
+}
+
+func evaluateRules(
+ wi *workloadidentityv1pb.WorkloadIdentity,
+ attrs *workloadidentityv1pb.Attrs,
+) error {
+ if len(wi.GetSpec().GetRules().GetAllow()) == 0 {
+ return nil
+ }
+ruleLoop:
+ for _, rule := range wi.GetSpec().GetRules().GetAllow() {
+ for _, condition := range rule.GetConditions() {
+ val, err := getFieldStringValue(attrs, condition.Attribute)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if val != condition.Equals {
+ continue ruleLoop
+ }
+ }
+ return nil
+ }
+ // TODO: Eventually, we'll need to work support for deny rules into here.
+ return trace.AccessDenied("no matching rule found")
+}
+
+func (s *IssuanceService) deriveAttrs(
+ authzCtx *authz.Context,
+ workloadAttrs *workloadidentityv1pb.WorkloadAttrs,
+) (*workloadidentityv1pb.Attrs, error) {
+ attrs := &workloadidentityv1pb.Attrs{
+ Workload: workloadAttrs,
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: authzCtx.Identity.GetIdentity().Username,
+ IsBot: authzCtx.Identity.GetIdentity().BotName != "",
+ BotName: authzCtx.Identity.GetIdentity().BotName,
+ Labels: authzCtx.User.GetAllLabels(),
+ },
+ }
+
+ return attrs, nil
+}
+
+var defaultMaxTTL = 24 * time.Hour
+
+func (s *IssuanceService) IssueWorkloadIdentity(
+ ctx context.Context,
+ req *workloadidentityv1pb.IssueWorkloadIdentityRequest,
+) (*workloadidentityv1pb.IssueWorkloadIdentityResponse, error) {
+ if !experiment.Enabled() {
+ return nil, trace.AccessDenied("workload identity issuance experiment is disabled")
+ }
+
+ authCtx, err := s.authorizer.Authorize(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ switch {
+ case req.GetName() == "":
+ return nil, trace.BadParameter("name: is required")
+ case req.GetCredential() == nil:
+ return nil, trace.BadParameter("at least one credential type must be requested")
+ }
+
+ wi, err := s.cache.GetWorkloadIdentity(ctx, req.GetName())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ // Check the principal has access to the workload identity resource by
+ // virtue of WorkloadIdentityLabels on a role.
+ if err := authCtx.Checker.CheckAccess(
+ types.Resource153ToResourceWithLabels(wi),
+ services.AccessState{},
+ ); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ attrs, err := s.deriveAttrs(authCtx, req.GetWorkloadAttrs())
+ if err != nil {
+ return nil, trace.Wrap(err, "deriving attributes")
+ }
+ // Evaluate any rules explicitly configured by the user
+ if err := evaluateRules(wi, attrs); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ // Perform any templating
+ spiffeIDPath, err := templateString(wi.GetSpec().GetSpiffe().GetId(), attrs)
+ if err != nil {
+ return nil, trace.Wrap(err, "templating spec.spiffe.id")
+ }
+ spiffeID, err := spiffeid.FromURI(&url.URL{
+ Scheme: "spiffe",
+ Host: s.clusterName,
+ Path: spiffeIDPath,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err, "creating SPIFFE ID")
+ }
+
+ hint, err := templateString(wi.GetSpec().GetSpiffe().GetHint(), attrs)
+ if err != nil {
+ return nil, trace.Wrap(err, "templating spec.spiffe.hint")
+ }
+
+ // TODO(noah): Add more sophisticated control of the TTL.
+ ttl := time.Hour
+ if req.RequestedTtl != nil && req.RequestedTtl.AsDuration() != 0 {
+ ttl = req.RequestedTtl.AsDuration()
+ if ttl > defaultMaxTTL {
+ ttl = defaultMaxTTL
+ }
+ }
+
+ now := s.clock.Now()
+ notBefore := now.Add(-1 * time.Minute)
+ notAfter := now.Add(ttl)
+
+ // Prepare event
+ evt := &apievents.SPIFFESVIDIssued{
+ Metadata: apievents.Metadata{
+ Type: events.SPIFFESVIDIssuedEvent,
+ Code: events.SPIFFESVIDIssuedSuccessCode,
+ },
+ UserMetadata: authz.ClientUserMetadata(ctx),
+ ConnectionMetadata: authz.ConnectionMetadata(ctx),
+ SPIFFEID: spiffeID.String(),
+ Hint: hint,
+ WorkloadIdentity: wi.GetMetadata().GetName(),
+ WorkloadIdentityRevision: wi.GetMetadata().GetRevision(),
+ }
+ cred := &workloadidentityv1pb.Credential{
+ WorkloadIdentityName: wi.GetMetadata().GetName(),
+ WorkloadIdentityRevision: wi.GetMetadata().GetRevision(),
+
+ SpiffeId: spiffeID.String(),
+ Hint: hint,
+
+ ExpiresAt: timestamppb.New(notAfter),
+ Ttl: durationpb.New(ttl),
+ }
+
+ switch v := req.GetCredential().(type) {
+ case *workloadidentityv1pb.IssueWorkloadIdentityRequest_X509SvidParams:
+ evt.SVIDType = "x509"
+ certDer, certSerial, err := s.issueX509SVID(
+ ctx,
+ v.X509SvidParams,
+ notBefore,
+ notAfter,
+ spiffeID,
+ )
+ if err != nil {
+ return nil, trace.Wrap(err, "issuing X509 SVID")
+ }
+ serialStr := serialString(certSerial)
+ cred.Credential = &workloadidentityv1pb.Credential_X509Svid{
+ X509Svid: &workloadidentityv1pb.X509SVIDCredential{
+ Cert: certDer,
+ SerialNumber: serialStr,
+ },
+ }
+ evt.SerialNumber = serialStr
+ case *workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams:
+ evt.SVIDType = "jwt"
+ signedJwt, jti, err := s.issueJWTSVID(
+ ctx,
+ v.JwtSvidParams,
+ now,
+ notAfter,
+ spiffeID,
+ )
+ if err != nil {
+ return nil, trace.Wrap(err, "issuing JWT SVID")
+ }
+ cred.Credential = &workloadidentityv1pb.Credential_JwtSvid{
+ JwtSvid: &workloadidentityv1pb.JWTSVIDCredential{
+ Jwt: signedJwt,
+ Jti: jti,
+ },
+ }
+ evt.JTI = jti
+ default:
+ return nil, trace.BadParameter("credential: unknown type %T", req.GetCredential())
+ }
+
+ if err := s.emitter.EmitAuditEvent(ctx, evt); err != nil {
+ s.logger.WarnContext(
+ ctx,
+ "failed to emit audit event for SVID issuance",
+ "error", err,
+ "event", evt,
+ )
+ }
+
+ return &workloadidentityv1pb.IssueWorkloadIdentityResponse{
+ Credential: cred,
+ }, nil
+}
+
+func generateCertSerial() (*big.Int, error) {
+ serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ return rand.Int(rand.Reader, serialNumberLimit)
+}
+
+func x509Template(
+ serialNumber *big.Int,
+ notBefore time.Time,
+ notAfter time.Time,
+ spiffeID spiffeid.ID,
+) *x509.Certificate {
+ return &x509.Certificate{
+ SerialNumber: serialNumber,
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ // SPEC(X509-SVID) 4.3. Key Usage:
+ // - Leaf SVIDs MUST NOT set keyCertSign or cRLSign.
+ // - Leaf SVIDs MUST set digitalSignature
+ // - They MAY set keyEncipherment and/or keyAgreement;
+ KeyUsage: x509.KeyUsageDigitalSignature |
+ x509.KeyUsageKeyEncipherment |
+ x509.KeyUsageKeyAgreement,
+ // SPEC(X509-SVID) 4.4. Extended Key Usage:
+ // - Leaf SVIDs SHOULD include this extension, and it MAY be marked as critical.
+ // - When included, fields id-kp-serverAuth and id-kp-clientAuth MUST be set.
+ ExtKeyUsage: []x509.ExtKeyUsage{
+ x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth,
+ },
+ // SPEC(X509-SVID) 4.1. Basic Constraints:
+ // - leaf certificates MUST set the cA field to false
+ BasicConstraintsValid: true,
+ IsCA: false,
+
+ // SPEC(X509-SVID) 2. SPIFFE ID:
+ // - The corresponding SPIFFE ID is set as a URI type in the Subject Alternative Name extension
+ // - An X.509 SVID MUST contain exactly one URI SAN, and by extension, exactly one SPIFFE ID.
+ // - An X.509 SVID MAY contain any number of other SAN field types, including DNS SANs.
+ URIs: []*url.URL{spiffeID.URL()},
+ }
+}
+
+func (s *IssuanceService) getX509CA(
+ ctx context.Context,
+) (_ *tlsca.CertAuthority, err error) {
+ ctx, span := tracer.Start(ctx, "IssuanceService/getX509CA")
+ defer func() { tracing.EndSpan(span, err) }()
+
+ ca, err := s.cache.GetCertAuthority(ctx, types.CertAuthID{
+ Type: types.SPIFFECA,
+ DomainName: s.clusterName,
+ }, true)
+ tlsCert, tlsSigner, err := s.keyStore.GetTLSCertAndSigner(ctx, ca)
+ if err != nil {
+ return nil, trace.Wrap(err, "getting CA cert and key")
+ }
+ tlsCA, err := tlsca.FromCertAndSigner(tlsCert, tlsSigner)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return tlsCA, nil
+}
+
+func (s *IssuanceService) issueX509SVID(
+ ctx context.Context,
+ params *workloadidentityv1pb.X509SVIDParams,
+ notBefore time.Time,
+ notAfter time.Time,
+ spiffeID spiffeid.ID,
+) (_ []byte, _ *big.Int, err error) {
+ ctx, span := tracer.Start(ctx, "IssuanceService/issueX509SVID")
+ defer func() { tracing.EndSpan(span, err) }()
+
+ switch {
+ case params == nil:
+ return nil, nil, trace.BadParameter("x509_svid_params: is required")
+ case len(params.PublicKey) == 0:
+ return nil, nil, trace.BadParameter("x509_svid_params.public_key: is required")
+ }
+
+ pubKey, err := x509.ParsePKIXPublicKey(params.PublicKey)
+ if err != nil {
+ return nil, nil, trace.Wrap(err, "parsing public key")
+ }
+
+ certSerial, err := generateCertSerial()
+ if err != nil {
+ return nil, nil, trace.Wrap(err, "generating certificate serial")
+ }
+ template := x509Template(certSerial, notBefore, notAfter, spiffeID)
+
+ ca, err := s.getX509CA(ctx)
+ if err != nil {
+ return nil, nil, trace.Wrap(err, "fetching CA to sign X509 SVID")
+ }
+ certBytes, err := x509.CreateCertificate(
+ rand.Reader, template, ca.Cert, pubKey, ca.Signer,
+ )
+ if err != nil {
+ return nil, nil, trace.Wrap(err)
+ }
+
+ return certBytes, certSerial, nil
+}
+
+const jtiLength = 16
+
+func (s *IssuanceService) getJWTIssuerKey(
+ ctx context.Context,
+) (_ *jwt.Key, err error) {
+ ctx, span := tracer.Start(ctx, "IssuanceService/getJWTIssuerKey")
+ defer func() { tracing.EndSpan(span, err) }()
+
+ ca, err := s.cache.GetCertAuthority(ctx, types.CertAuthID{
+ Type: types.SPIFFECA,
+ DomainName: s.clusterName,
+ }, true)
+ if err != nil {
+ return nil, trace.Wrap(err, "getting SPIFFE CA")
+ }
+
+ jwtSigner, err := s.keyStore.GetJWTSigner(ctx, ca)
+ if err != nil {
+ return nil, trace.Wrap(err, "getting JWT signer")
+ }
+
+ jwtKey, err := services.GetJWTSigner(
+ jwtSigner, s.clusterName, s.clock,
+ )
+ if err != nil {
+ return nil, trace.Wrap(err, "creating JWT signer")
+ }
+ return jwtKey, nil
+}
+
+func (s *IssuanceService) issueJWTSVID(
+ ctx context.Context,
+ params *workloadidentityv1pb.JWTSVIDParams,
+ now time.Time,
+ notAfter time.Time,
+ spiffeID spiffeid.ID,
+) (_ string, _ string, err error) {
+ ctx, span := tracer.Start(ctx, "IssuanceService/issueJWTSVID")
+ defer func() { tracing.EndSpan(span, err) }()
+
+ switch {
+ case params == nil:
+ return "", "", trace.BadParameter("jwt_svid_params: is required")
+ case len(params.Audiences) == 0:
+ return "", "", trace.BadParameter("jwt_svid_params.audiences: at least one audience should be specified")
+ }
+
+ jti, err := utils.CryptoRandomHex(jtiLength)
+ if err != nil {
+ return "", "", trace.Wrap(err, "generating JTI")
+ }
+
+ key, err := s.getJWTIssuerKey(ctx)
+ if err != nil {
+ return "", "", trace.Wrap(err, "getting JWT issuer key")
+ }
+
+ // Determine the public address of the proxy for inclusion in the JWT as
+ // the issuer for purposes of OIDC compatibility.
+ issuer, err := oidc.IssuerForCluster(ctx, s.cache, "/workload-identity")
+ if err != nil {
+ return "", "", trace.Wrap(err, "determining issuer URI")
+ }
+
+ signed, err := key.SignJWTSVID(jwt.SignParamsJWTSVID{
+ Audiences: params.Audiences,
+ SPIFFEID: spiffeID,
+ JTI: jti,
+ Issuer: issuer,
+
+ SetIssuedAt: now,
+ SetExpiry: notAfter,
+ })
+ if err != nil {
+ return "", "", trace.Wrap(err, "signing jwt")
+ }
+
+ return signed, jti, nil
+}
+
+func (s *IssuanceService) IssueWorkloadIdentities(
+ ctx context.Context,
+ req *workloadidentityv1pb.IssueWorkloadIdentitiesRequest,
+) (*workloadidentityv1pb.IssueWorkloadIdentitiesResponse, error) {
+ // TODO(noah): Coming to a PR near you soon!
+ return nil, trace.NotImplemented("not implemented")
+}
+
+func serialString(serial *big.Int) string {
+ hex := serial.Text(16)
+ if len(hex)%2 == 1 {
+ hex = "0" + hex
+ }
+
+ out := strings.Builder{}
+ for i := 0; i < len(hex); i += 2 {
+ if i != 0 {
+ out.WriteString(":")
+ }
+ out.WriteString(hex[i : i+2])
+ }
+ return out.String()
+}
diff --git a/lib/auth/machineid/workloadidentityv1/issuer_service_test.go b/lib/auth/machineid/workloadidentityv1/issuer_service_test.go
new file mode 100644
index 0000000000000..bf18594416609
--- /dev/null
+++ b/lib/auth/machineid/workloadidentityv1/issuer_service_test.go
@@ -0,0 +1,223 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package workloadidentityv1
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
+ workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
+ "github.com/gravitational/teleport/api/types"
+)
+
+func Test_getFieldStringValue(t *testing.T) {
+ tests := []struct {
+ name string
+ in *workloadidentityv1pb.Attrs
+ attr string
+ want string
+ requireErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "success",
+ in: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "jeff",
+ },
+ },
+ attr: "user.name",
+ want: "jeff",
+ requireErr: require.NoError,
+ },
+ {
+ name: "bool",
+ in: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "jeff",
+ },
+ Workload: &workloadidentityv1pb.WorkloadAttrs{
+ Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
+ Attested: true,
+ },
+ },
+ },
+ attr: "workload.unix.attested",
+ want: "true",
+ requireErr: require.NoError,
+ },
+ {
+ name: "int32",
+ in: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "jeff",
+ },
+ Workload: &workloadidentityv1pb.WorkloadAttrs{
+ Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
+ Pid: 123,
+ },
+ },
+ },
+ attr: "workload.unix.pid",
+ want: "123",
+ requireErr: require.NoError,
+ },
+ {
+ name: "uint32",
+ in: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "jeff",
+ },
+ Workload: &workloadidentityv1pb.WorkloadAttrs{
+ Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
+ Gid: 123,
+ },
+ },
+ },
+ attr: "workload.unix.gid",
+ want: "123",
+ requireErr: require.NoError,
+ },
+ {
+ name: "non-string final field",
+ in: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "user",
+ },
+ },
+ attr: "user",
+ requireErr: func(t require.TestingT, err error, i ...interface{}) {
+ require.ErrorContains(t, err, "attribute \"user\" of type \"message\" cannot be converted to string")
+ },
+ },
+ {
+ // We mostly just want this to not panic.
+ name: "nil root",
+ in: nil,
+ attr: "user.name",
+ want: "",
+ requireErr: require.NoError,
+ },
+ {
+ // We mostly just want this to not panic.
+ name: "nil submessage",
+ in: &workloadidentityv1pb.Attrs{
+ User: nil,
+ },
+ attr: "user.name",
+ want: "",
+ requireErr: require.NoError,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, gotErr := getFieldStringValue(tt.in, tt.attr)
+ tt.requireErr(t, gotErr)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_templateString(t *testing.T) {
+ tests := []struct {
+ name string
+ in string
+ want string
+ attrs *workloadidentityv1pb.Attrs
+ requireErr require.ErrorAssertionFunc
+ }{
+ {
+ name: "success mixed",
+ in: "hello{{user.name}}.{{user.name}} {{ workload.kubernetes.pod_name }}//{{ workload.kubernetes.namespace}}",
+ want: "hellojeff.jeff pod1//default",
+ attrs: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "jeff",
+ },
+ Workload: &workloadidentityv1pb.WorkloadAttrs{
+ Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
+ PodName: "pod1",
+ Namespace: "default",
+ },
+ },
+ },
+ requireErr: require.NoError,
+ },
+ {
+ name: "success with spaces",
+ in: "hello {{user.name}}",
+ want: "hello jeff",
+ attrs: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "jeff",
+ },
+ },
+ requireErr: require.NoError,
+ },
+ {
+ name: "fail due to unset",
+ in: "hello {{workload.kubernetes.pod_name}}",
+ attrs: &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "jeff",
+ },
+ },
+ requireErr: func(t require.TestingT, err error, i ...interface{}) {
+ require.ErrorContains(t, err, "attribute \"workload.kubernetes.pod_name\" unset")
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, gotErr := templateString(tt.in, tt.attrs)
+ tt.requireErr(t, gotErr)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_evaluateRules(t *testing.T) {
+ attrs := &workloadidentityv1pb.Attrs{
+ User: &workloadidentityv1pb.UserAttrs{
+ Name: "foo",
+ },
+ }
+ wi := &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "test",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Equals: "foo",
+ },
+ },
+ },
+ },
+ },
+ },
+ }
+ err := evaluateRules(wi, attrs)
+ require.NoError(t, err)
+}
diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go
index 1c0601a34dd54..b911885031995 100644
--- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go
+++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go
@@ -18,6 +18,8 @@ package workloadidentityv1_test
import (
"context"
+ "crypto/rsa"
+ "crypto/x509"
"errors"
"fmt"
"net"
@@ -26,6 +28,7 @@ import (
"testing"
"time"
+ "github.com/go-jose/go-jose/v3/jwt"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gravitational/trace"
@@ -33,6 +36,7 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
+ "google.golang.org/protobuf/types/known/durationpb"
headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
@@ -40,9 +44,13 @@ import (
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
+ "github.com/gravitational/teleport/lib/auth/machineid/workloadidentityv1/experiment"
+ "github.com/gravitational/teleport/lib/auth/native"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
+ libjwt "github.com/gravitational/teleport/lib/jwt"
"github.com/gravitational/teleport/lib/modules"
+ "github.com/gravitational/teleport/lib/services"
)
func TestMain(m *testing.M) {
@@ -74,6 +82,396 @@ func newTestTLSServer(t testing.TB) (*auth.TestTLSServer, *eventstest.MockRecord
return srv, emitter
}
+func TestIssueWorkloadIdentity(t *testing.T) {
+ experimentStatus := experiment.Enabled()
+ defer experiment.SetEnabled(experimentStatus)
+ experiment.SetEnabled(true)
+
+ srv, eventRecorder := newTestTLSServer(t)
+ ctx := context.Background()
+ clock := srv.Auth().GetClock()
+
+ // Upsert a fake proxy to ensure we have a public address to use for the
+ // issuer.
+ proxy, err := types.NewServer("proxy", types.KindProxy, types.ServerSpecV2{
+ PublicAddrs: []string{"teleport.example.com"},
+ })
+ require.NoError(t, err)
+ err = srv.Auth().UpsertProxy(ctx, proxy)
+ require.NoError(t, err)
+ wantIssuer := "https://teleport.example.com/workload-identity"
+
+ // Fetch X509 SPIFFE CA for validation of signature later
+ spiffeX509CA, err := srv.Auth().GetCertAuthority(ctx, types.CertAuthID{
+ Type: types.SPIFFECA,
+ DomainName: srv.ClusterName(),
+ }, false)
+ require.NoError(t, err)
+ spiffeX509CAPool, err := services.CertPool(spiffeX509CA)
+ require.NoError(t, err)
+ // Fetch JWT CA to validate JWTs
+ jwtCA, err := srv.Auth().GetCertAuthority(ctx, types.CertAuthID{
+ Type: types.SPIFFECA,
+ DomainName: "localhost",
+ }, true)
+ require.NoError(t, err)
+ jwtSigner, err := srv.Auth().GetKeyStore().GetJWTSigner(ctx, jwtCA)
+ require.NoError(t, err)
+ kid := libjwt.KeyID(jwtSigner.Public().(*rsa.PublicKey))
+
+ wildcardAccess, _, err := auth.CreateUserAndRole(
+ srv.Auth(),
+ "dog",
+ []string{},
+ []types.Rule{},
+ auth.WithRoleMutator(func(role types.Role) {
+ role.SetWorkloadIdentityLabels(types.Allow, types.Labels{
+ types.Wildcard: []string{types.Wildcard},
+ })
+ }),
+ )
+ require.NoError(t, err)
+ wilcardAccessClient, err := srv.NewClient(auth.TestUser(wildcardAccess.GetName()))
+ require.NoError(t, err)
+
+ specificAccess, _, err := auth.CreateUserAndRole(
+ srv.Auth(),
+ "cat",
+ []string{},
+ []types.Rule{},
+ auth.WithRoleMutator(func(role types.Role) {
+ role.SetWorkloadIdentityLabels(types.Allow, types.Labels{
+ "foo": []string{"bar"},
+ })
+ }),
+ )
+ require.NoError(t, err)
+ specificAccessClient, err := srv.NewClient(auth.TestUser(specificAccess.GetName()))
+ require.NoError(t, err)
+
+ // Generate a keypair to generate x509 SVIDs for.
+ workloadKey, err := native.GenerateRSAPrivateKey()
+ require.NoError(t, err)
+ workloadKeyPubBytes, err := x509.MarshalPKIXPublicKey(workloadKey.Public())
+ require.NoError(t, err)
+
+ // Create some WorkloadIdentity resources
+ full, err := srv.Auth().CreateWorkloadIdentity(ctx, &workloadidentityv1pb.WorkloadIdentity{
+ Kind: types.KindWorkloadIdentity,
+ Version: types.V1,
+ Metadata: &headerv1.Metadata{
+ Name: "full",
+ },
+ Spec: &workloadidentityv1pb.WorkloadIdentitySpec{
+ Rules: &workloadidentityv1pb.WorkloadIdentityRules{
+ Allow: []*workloadidentityv1pb.WorkloadIdentityRule{
+ {
+ Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{
+ {
+ Attribute: "user.name",
+ Equals: "dog",
+ },
+ {
+ Attribute: "workload.kubernetes.namespace",
+ Equals: "default",
+ },
+ },
+ },
+ },
+ },
+ Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{
+ Id: "/example/{{user.name}}/{{ workload.kubernetes.namespace }}/{{ workload.kubernetes.service_account }}",
+ Hint: "Wow - what a lovely hint, {{user.name}}!",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ workloadAttrs := func(f func(attrs *workloadidentityv1pb.WorkloadAttrs)) *workloadidentityv1pb.WorkloadAttrs {
+ attrs := &workloadidentityv1pb.WorkloadAttrs{
+ Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
+ Attested: true,
+ Namespace: "default",
+ PodName: "test",
+ ServiceAccount: "bar",
+ },
+ }
+ if f != nil {
+ f(attrs)
+ }
+ return attrs
+ }
+ tests := []struct {
+ name string
+ client *authclient.Client
+ req *workloadidentityv1pb.IssueWorkloadIdentityRequest
+ requireErr require.ErrorAssertionFunc
+ assert func(*testing.T, *workloadidentityv1pb.IssueWorkloadIdentityResponse)
+ }{
+ {
+ name: "jwt svid",
+ client: wilcardAccessClient,
+ req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{
+ Name: full.GetMetadata().GetName(),
+ Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{
+ JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{
+ Audiences: []string{"example.com", "test.example.com"},
+ },
+ },
+ WorkloadAttrs: workloadAttrs(nil),
+ },
+ requireErr: require.NoError,
+ assert: func(t *testing.T, res *workloadidentityv1pb.IssueWorkloadIdentityResponse) {
+ cred := res.Credential
+ require.NotNil(t, res.Credential)
+
+ wantTTL := time.Hour
+ wantSPIFFEID := "spiffe://localhost/example/dog/default/bar"
+ require.Empty(t, cmp.Diff(
+ cred,
+ &workloadidentityv1pb.Credential{
+ Ttl: durationpb.New(wantTTL),
+ SpiffeId: wantSPIFFEID,
+ Hint: "Wow - what a lovely hint, dog!",
+ WorkloadIdentityName: full.GetMetadata().GetName(),
+ WorkloadIdentityRevision: full.GetMetadata().GetRevision(),
+ },
+ protocmp.Transform(),
+ protocmp.IgnoreFields(
+ &workloadidentityv1pb.Credential{},
+ "expires_at",
+ ),
+ protocmp.IgnoreOneofs(
+ &workloadidentityv1pb.Credential{},
+ "credential",
+ ),
+ ))
+ // Check expiry makes sense
+ require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second)
+
+ // Check the JWT
+ parsed, err := jwt.ParseSigned(cred.GetJwtSvid().GetJwt())
+ require.NoError(t, err)
+
+ claims := jwt.Claims{}
+ err = parsed.Claims(jwtSigner.Public(), &claims)
+ require.NoError(t, err)
+ // Check headers
+ require.Len(t, parsed.Headers, 1)
+ require.Equal(t, kid, parsed.Headers[0].KeyID)
+ // Check claims
+ require.Equal(t, wantSPIFFEID, claims.Subject)
+ require.NotEmpty(t, claims.ID)
+ require.Equal(t, jwt.Audience{"example.com", "test.example.com"}, claims.Audience)
+ require.Equal(t, wantIssuer, claims.Issuer)
+ require.WithinDuration(t, clock.Now().Add(wantTTL), claims.Expiry.Time(), 5*time.Second)
+ require.WithinDuration(t, clock.Now(), claims.IssuedAt.Time(), 5*time.Second)
+
+ // Check audit log event
+ evt, ok := eventRecorder.LastEvent().(*events.SPIFFESVIDIssued)
+ require.True(t, ok)
+ require.NotEmpty(t, evt.ConnectionMetadata.RemoteAddr)
+ require.Equal(t, claims.ID, evt.JTI)
+ require.Equal(t, claims.ID, cred.GetJwtSvid().GetJti())
+ require.Empty(t, cmp.Diff(
+ evt,
+ &events.SPIFFESVIDIssued{
+ Metadata: events.Metadata{
+ Type: libevents.SPIFFESVIDIssuedEvent,
+ Code: libevents.SPIFFESVIDIssuedSuccessCode,
+ },
+ UserMetadata: events.UserMetadata{
+ User: wildcardAccess.GetName(),
+ UserKind: events.UserKind_USER_KIND_HUMAN,
+ },
+ SPIFFEID: "spiffe://localhost/example/dog/default/bar",
+ SVIDType: "jwt",
+ Hint: "Wow - what a lovely hint, dog!",
+ WorkloadIdentity: full.GetMetadata().GetName(),
+ WorkloadIdentityRevision: full.GetMetadata().GetRevision(),
+ },
+ cmpopts.IgnoreFields(
+ events.SPIFFESVIDIssued{},
+ "ConnectionMetadata",
+ "JTI",
+ ),
+ ))
+ },
+ },
+ {
+ name: "x509 svid",
+ client: wilcardAccessClient,
+ req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{
+ Name: full.GetMetadata().GetName(),
+ Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_X509SvidParams{
+ X509SvidParams: &workloadidentityv1pb.X509SVIDParams{
+ PublicKey: workloadKeyPubBytes,
+ },
+ },
+ WorkloadAttrs: workloadAttrs(nil),
+ },
+ requireErr: require.NoError,
+ assert: func(t *testing.T, res *workloadidentityv1pb.IssueWorkloadIdentityResponse) {
+ cred := res.Credential
+ require.NotNil(t, res.Credential)
+
+ wantSPIFFEID := "spiffe://localhost/example/dog/default/bar"
+ wantTTL := time.Hour
+ require.Empty(t, cmp.Diff(
+ cred,
+ &workloadidentityv1pb.Credential{
+ Ttl: durationpb.New(wantTTL),
+ SpiffeId: wantSPIFFEID,
+ Hint: "Wow - what a lovely hint, dog!",
+ WorkloadIdentityName: full.GetMetadata().GetName(),
+ WorkloadIdentityRevision: full.GetMetadata().GetRevision(),
+ },
+ protocmp.Transform(),
+ protocmp.IgnoreFields(
+ &workloadidentityv1pb.Credential{},
+ "expires_at",
+ ),
+ protocmp.IgnoreOneofs(
+ &workloadidentityv1pb.Credential{},
+ "credential",
+ ),
+ ))
+ // Check expiry makes sense
+ require.WithinDuration(t, clock.Now().Add(wantTTL), cred.GetExpiresAt().AsTime(), time.Second)
+
+ // Check the X509
+ cert, err := x509.ParseCertificate(cred.GetX509Svid().GetCert())
+ require.NoError(t, err)
+ // Check included public key matches
+ require.Equal(t, workloadKey.Public(), cert.PublicKey)
+ // Check cert expiry
+ require.WithinDuration(t, clock.Now().Add(wantTTL), cert.NotAfter, time.Second)
+ // Check cert nbf
+ require.WithinDuration(t, clock.Now().Add(-1*time.Minute), cert.NotBefore, time.Second)
+ // Check cert TTL
+ require.Equal(t, cert.NotAfter.Sub(cert.NotBefore), wantTTL+time.Minute)
+
+ // Check against SPIFFE SPEC
+ // References are to https://github.com/spiffe/spiffe/blob/main/standards/X509-SVID.md
+ // 2: An X.509 SVID MUST contain exactly one URI SAN, and by extension, exactly one SPIFFE ID
+ require.Len(t, cert.URIs, 1)
+ require.Equal(t, wantSPIFFEID, cert.URIs[0].String())
+ // 4.1: leaf certificates MUST set the cA field to false.
+ require.False(t, cert.IsCA)
+ require.Greater(t, cert.KeyUsage&x509.KeyUsageDigitalSignature, 0)
+ // 4.3: They MAY set keyEncipherment and/or keyAgreement
+ require.Greater(t, cert.KeyUsage&x509.KeyUsageKeyEncipherment, 0)
+ require.Greater(t, cert.KeyUsage&x509.KeyUsageKeyAgreement, 0)
+ // 4.3: Leaf SVIDs MUST NOT set keyCertSign or cRLSign
+ require.EqualValues(t, 0, cert.KeyUsage&x509.KeyUsageCertSign)
+ require.EqualValues(t, 0, cert.KeyUsage&x509.KeyUsageCRLSign)
+ // 4.4: When included, fields id-kp-serverAuth and id-kp-clientAuth MUST be set.
+ require.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageServerAuth)
+ require.Contains(t, cert.ExtKeyUsage, x509.ExtKeyUsageClientAuth)
+
+ // Check cert signature is valid
+ _, err = cert.Verify(x509.VerifyOptions{
+ Roots: spiffeX509CAPool,
+ CurrentTime: srv.Auth().GetClock().Now(),
+ })
+ require.NoError(t, err)
+
+ // Check audit log event
+ evt, ok := eventRecorder.LastEvent().(*events.SPIFFESVIDIssued)
+ require.True(t, ok)
+ require.NotEmpty(t, evt.ConnectionMetadata.RemoteAddr)
+ require.Equal(t, cred.GetX509Svid().GetSerialNumber(), evt.SerialNumber)
+ require.Empty(t, cmp.Diff(
+ evt,
+ &events.SPIFFESVIDIssued{
+ Metadata: events.Metadata{
+ Type: libevents.SPIFFESVIDIssuedEvent,
+ Code: libevents.SPIFFESVIDIssuedSuccessCode,
+ },
+ UserMetadata: events.UserMetadata{
+ User: wildcardAccess.GetName(),
+ UserKind: events.UserKind_USER_KIND_HUMAN,
+ },
+ SPIFFEID: "spiffe://localhost/example/dog/default/bar",
+ SVIDType: "x509",
+ Hint: "Wow - what a lovely hint, dog!",
+ WorkloadIdentity: full.GetMetadata().GetName(),
+ WorkloadIdentityRevision: full.GetMetadata().GetRevision(),
+ },
+ cmpopts.IgnoreFields(
+ events.SPIFFESVIDIssued{},
+ "ConnectionMetadata",
+ "SerialNumber",
+ ),
+ ))
+ },
+ },
+ {
+ name: "unauthorized by rules",
+ client: wilcardAccessClient,
+ req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{
+ Name: full.GetMetadata().GetName(),
+ Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{
+ JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{
+ Audiences: []string{"example.com", "test.example.com"},
+ },
+ },
+ WorkloadAttrs: workloadAttrs(func(attrs *workloadidentityv1pb.WorkloadAttrs) {
+ attrs.Kubernetes.Namespace = "not-default"
+ }),
+ },
+ requireErr: func(t require.TestingT, err error, i ...interface{}) {
+ require.True(t, trace.IsAccessDenied(err))
+ },
+ },
+ {
+ name: "unauthorized by labels",
+ client: specificAccessClient,
+ req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{
+ Name: full.GetMetadata().GetName(),
+ Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{
+ JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{
+ Audiences: []string{"example.com", "test.example.com"},
+ },
+ },
+ WorkloadAttrs: workloadAttrs(nil),
+ },
+ requireErr: func(t require.TestingT, err error, i ...interface{}) {
+ require.True(t, trace.IsAccessDenied(err))
+ },
+ },
+ {
+ name: "does not exist",
+ client: specificAccessClient,
+ req: &workloadidentityv1pb.IssueWorkloadIdentityRequest{
+ Name: "does-not-exist",
+ Credential: &workloadidentityv1pb.IssueWorkloadIdentityRequest_JwtSvidParams{
+ JwtSvidParams: &workloadidentityv1pb.JWTSVIDParams{
+ Audiences: []string{"example.com", "test.example.com"},
+ },
+ },
+ WorkloadAttrs: workloadAttrs(nil),
+ },
+ requireErr: func(t require.TestingT, err error, i ...interface{}) {
+ require.True(t, trace.IsNotFound(err))
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ eventRecorder.Reset()
+ c := workloadidentityv1pb.NewWorkloadIdentityIssuanceServiceClient(
+ tt.client.GetConnection(),
+ )
+ res, err := c.IssueWorkloadIdentity(ctx, tt.req)
+ tt.requireErr(t, err)
+ if tt.assert != nil {
+ tt.assert(t, res)
+ }
+ })
+ }
+}
+
func TestResourceService_CreateWorkloadIdentity(t *testing.T) {
t.Parallel()
srv, eventRecorder := newTestTLSServer(t)
diff --git a/lib/jwt/jwt.go b/lib/jwt/jwt.go
index bfdb78dba21c0..f86012a8e5184 100644
--- a/lib/jwt/jwt.go
+++ b/lib/jwt/jwt.go
@@ -266,6 +266,12 @@ type SignParamsJWTSVID struct {
// Issuer is the value that should be included in the `iss` claim of the
// created token.
Issuer string
+
+ // SetExpiry overrides the expiry time of the token. This causes the value
+ // of TTL to be ignored.
+ SetExpiry time.Time
+ // SetIssuedAt overrides the issued at time of the token.
+ SetIssuedAt time.Time
}
// SignJWTSVID signs a JWT SVID token.
@@ -297,6 +303,12 @@ func (k *Key) SignJWTSVID(p SignParamsJWTSVID) (string, error) {
// understand OIDC.
Issuer: p.Issuer,
}
+ if !p.SetIssuedAt.IsZero() {
+ claims.IssuedAt = jwt.NewNumericDate(p.SetIssuedAt)
+ }
+ if !p.SetExpiry.IsZero() {
+ claims.Expiry = jwt.NewNumericDate(p.SetExpiry)
+ }
// > 2.2. Key ID:
// >The kid header is optional.