From c6e46e3262c5616efa80402396fa58f6408efea7 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Tue, 10 Dec 2024 10:22:09 +0100 Subject: [PATCH] Workload ID: Add WorkloadIdentity local service and cache config (#49942) * Add WorkloadIdentity store and cache * Update lib/services/local/workload_identity.go Co-authored-by: Edward Dowling * Update lib/services/local/workload_identity.go Co-authored-by: Edward Dowling * Update lib/cache/resource_workload_identity.go Co-authored-by: Edoardo Spadolini --------- Co-authored-by: Edward Dowling Co-authored-by: Edoardo Spadolini --- lib/auth/accesspoint/accesspoint.go | 2 + lib/auth/auth.go | 9 + lib/auth/authclient/api.go | 7 + lib/auth/helpers.go | 1 + lib/auth/init.go | 4 + lib/cache/cache.go | 12 + lib/cache/cache_test.go | 13 + lib/cache/collections.go | 11 + lib/cache/resource_workload_identity.go | 119 +++++++ lib/cache/resource_workload_identity_test.go | 74 ++++ lib/service/service.go | 1 + lib/services/local/events.go | 42 +++ lib/services/local/workload_identity.go | 118 ++++++ lib/services/local/workload_identity_test.go | 355 +++++++++++++++++++ lib/services/workload_identity.go | 122 +++++++ lib/services/workload_identity_test.go | 231 ++++++++++++ 16 files changed, 1121 insertions(+) create mode 100644 lib/cache/resource_workload_identity.go create mode 100644 lib/cache/resource_workload_identity_test.go create mode 100644 lib/services/local/workload_identity.go create mode 100644 lib/services/local/workload_identity_test.go create mode 100644 lib/services/workload_identity.go create mode 100644 lib/services/workload_identity_test.go diff --git a/lib/auth/accesspoint/accesspoint.go b/lib/auth/accesspoint/accesspoint.go index d078d25d87b92..66bf51223990f 100644 --- a/lib/auth/accesspoint/accesspoint.go +++ b/lib/auth/accesspoint/accesspoint.go @@ -103,6 +103,7 @@ type Config struct { Users services.UsersService WebSession types.WebSessionInterface WebToken types.WebTokenInterface + WorkloadIdentity cache.WorkloadIdentityReader DynamicWindowsDesktops services.DynamicWindowsDesktops WindowsDesktops services.WindowsDesktops AutoUpdateService services.AutoUpdateServiceGetter @@ -203,6 +204,7 @@ func NewCache(cfg Config) (*cache.Cache, error) { Users: cfg.Users, WebSession: cfg.WebSession, WebToken: cfg.WebToken, + WorkloadIdentity: cfg.WorkloadIdentity, WindowsDesktops: cfg.WindowsDesktops, DynamicWindowsDesktops: cfg.DynamicWindowsDesktops, ProvisioningStates: cfg.ProvisioningStates, diff --git a/lib/auth/auth.go b/lib/auth/auth.go index dedda21f57012..342e52e821194 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -401,6 +401,13 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { return nil, trace.Wrap(err, "creating GitServer service") } } + if cfg.WorkloadIdentity == nil { + workloadIdentity, err := local.NewWorkloadIdentityService(cfg.Backend) + if err != nil { + return nil, trace.Wrap(err, "creating WorkloadIdentity service") + } + cfg.WorkloadIdentity = workloadIdentity + } if cfg.Logger == nil { cfg.Logger = slog.With(teleport.ComponentKey, teleport.ComponentAuth) } @@ -499,6 +506,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { IdentityCenter: cfg.IdentityCenter, PluginStaticCredentials: cfg.PluginStaticCredentials, GitServers: cfg.GitServers, + WorkloadIdentities: cfg.WorkloadIdentity, } as := Server{ @@ -718,6 +726,7 @@ type Services struct { services.IdentityCenter services.PluginStaticCredentials services.GitServers + services.WorkloadIdentities } // GetWebSession returns existing web session described by req. diff --git a/lib/auth/authclient/api.go b/lib/auth/authclient/api.go index 2a9d3095b4137..409e4850e8a97 100644 --- a/lib/auth/authclient/api.go +++ b/lib/auth/authclient/api.go @@ -38,6 +38,7 @@ import ( userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2" userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/discoveryconfig" @@ -1229,6 +1230,12 @@ type Cache interface { // pagination. ListSPIFFEFederations(ctx context.Context, pageSize int, lastToken string) ([]*machineidv1.SPIFFEFederation, string, error) + // GetWorkloadIdentity gets a WorkloadIdentity by name. + GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) + // ListWorkloadIdentities lists all SPIFFE Federations using Google style + // pagination. + ListWorkloadIdentities(ctx context.Context, pageSize int, lastToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) + // ListStaticHostUsers lists static host users. ListStaticHostUsers(ctx context.Context, pageSize int, startKey string) ([]*userprovisioningpb.StaticHostUser, string, error) // GetStaticHostUser returns a static host user by name. diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 71804b4ca0049..58079d4745374 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -360,6 +360,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { SecReports: svces.SecReports, SnowflakeSession: svces.Identity, SPIFFEFederations: svces.SPIFFEFederations, + WorkloadIdentity: svces.WorkloadIdentities, StaticHostUsers: svces.StaticHostUser, Trust: svces.TrustInternal, UserGroups: svces.UserGroups, diff --git a/lib/auth/init.go b/lib/auth/init.go index 61bb8cba0e447..9d86d21c1106f 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -322,6 +322,10 @@ type InitConfig struct { // SPIFFEFederations is a service that manages storing SPIFFE federations. SPIFFEFederations services.SPIFFEFederations + // WorkloadIdentity is the service for storing and retrieving + // WorkloadIdentity resources. + WorkloadIdentity services.WorkloadIdentities + // StaticHostUsers is a service that manages host users that should be // created on SSH nodes. StaticHostUsers services.StaticHostUser diff --git a/lib/cache/cache.go b/lib/cache/cache.go index fcb2a3bf7da5f..b29d9cbd07054 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -201,6 +201,7 @@ func ForAuth(cfg Config) Config { {Kind: types.KindIdentityCenterAccountAssignment}, {Kind: types.KindPluginStaticCredentials}, {Kind: types.KindGitServer}, + {Kind: types.KindWorkloadIdentity}, } cfg.QueueSize = defaults.AuthQueueSize // We don't want to enable partial health for auth cache because auth uses an event stream @@ -556,6 +557,7 @@ type Cache struct { identityCenterCache *local.IdentityCenterService pluginStaticCredentialsCache *local.PluginStaticCredentialsService gitServersCache *local.GitServerService + workloadIdentityCache workloadIdentityCacher // closed indicates that the cache has been closed closed atomic.Bool @@ -738,6 +740,9 @@ type Config struct { SPIFFEFederations SPIFFEFederationReader // StaticHostUsers is the static host user service. StaticHostUsers services.StaticHostUser + // WorkloadIdentity is the upstream Workload Identities service that we're + // caching + WorkloadIdentity WorkloadIdentityReader // Backend is a backend for local cache Backend backend.Backend // MaxRetryPeriod is the maximum period between cache retries on failures @@ -1008,6 +1013,12 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } + workloadIdentityCache, err := local.NewWorkloadIdentityService(config.Backend) + if err != nil { + cancel() + return nil, trace.Wrap(err) + } + staticHostUserCache, err := local.NewStaticHostUserService(config.Backend) if err != nil { cancel() @@ -1094,6 +1105,7 @@ func New(config Config) (*Cache, error) { identityCenterCache: identityCenterCache, pluginStaticCredentialsCache: pluginStaticCredentialsCache, gitServersCache: gitServersCache, + workloadIdentityCache: workloadIdentityCache, Logger: log.WithFields(log.Fields{ teleport.ComponentKey: config.Component, }), diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index af4b4d195bce4..e60f0acac0174 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -143,6 +143,7 @@ type testPack struct { identityCenter services.IdentityCenter pluginStaticCredentials *local.PluginStaticCredentialsService gitServers services.GitServers + workloadIdentity *local.WorkloadIdentityService } // testFuncs are functions to support testing an object in a cache. @@ -365,6 +366,12 @@ func newPackWithoutCache(dir string, opts ...packOption) (*testPack, error) { } p.spiffeFederations = spiffeFederationsSvc + workloadIdentitySvc, err := local.NewWorkloadIdentityService(p.backend) + if err != nil { + return nil, trace.Wrap(err) + } + p.workloadIdentity = workloadIdentitySvc + databaseObjectsSvc, err := local.NewDatabaseObjectService(p.backend) if err != nil { return nil, trace.Wrap(err) @@ -470,6 +477,7 @@ func newPack(dir string, setupConfig func(c Config) Config, opts ...packOption) IdentityCenter: p.identityCenter, PluginStaticCredentials: p.pluginStaticCredentials, GitServers: p.gitServers, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -881,6 +889,7 @@ func TestCompletenessInit(t *testing.T) { StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, ProvisioningStates: p.provisioningStates, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, IdentityCenter: p.identityCenter, PluginStaticCredentials: p.pluginStaticCredentials, @@ -969,6 +978,7 @@ func TestCompletenessReset(t *testing.T) { ProvisioningStates: p.provisioningStates, IdentityCenter: p.identityCenter, PluginStaticCredentials: p.pluginStaticCredentials, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, GitServers: p.gitServers, @@ -1181,6 +1191,7 @@ func TestListResources_NodesTTLVariant(t *testing.T) { ProvisioningStates: p.provisioningStates, IdentityCenter: p.identityCenter, PluginStaticCredentials: p.pluginStaticCredentials, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, neverOK: true, // ensure reads are never healthy @@ -1278,6 +1289,7 @@ func initStrategy(t *testing.T) { ProvisioningStates: p.provisioningStates, IdentityCenter: p.identityCenter, PluginStaticCredentials: p.pluginStaticCredentials, + WorkloadIdentity: p.workloadIdentity, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, GitServers: p.gitServers, @@ -3556,6 +3568,7 @@ func TestCacheWatchKindExistsInEvents(t *testing.T) { types.KindIdentityCenterPrincipalAssignment: types.Resource153ToLegacy(newIdentityCenterPrincipalAssignment("some_principal_assignment")), types.KindPluginStaticCredentials: &types.PluginStaticCredentialsV1{}, types.KindGitServer: &types.ServerV2{}, + types.KindWorkloadIdentity: types.Resource153ToLegacy(newWorkloadIdentity("some_identifier")), } for name, cfg := range cases { diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 2635a0d71ea04..f73f83fcddb83 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -41,6 +41,7 @@ import ( userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v2" userspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/users/v1" usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/discoveryconfig" @@ -178,6 +179,7 @@ type cacheCollections struct { identityCenterAccountAssignments collectionReader[identityCenterAccountAssignmentGetter] pluginStaticCredentials collectionReader[pluginStaticCredentialsGetter] gitServers collectionReader[services.GitServerGetter] + workloadIdentity collectionReader[WorkloadIdentityReader] } // setupCollections returns a registry of collections. @@ -706,6 +708,15 @@ func setupCollections(c *Cache, watches []types.WatchKind) (*cacheCollections, e watch: watch, } collections.byKind[resourceKind] = collections.spiffeFederations + case types.KindWorkloadIdentity: + if c.Config.WorkloadIdentity == nil { + return nil, trace.BadParameter("missing parameter WorkloadIdentity") + } + collections.workloadIdentity = &genericCollection[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader, workloadIdentityExecutor]{ + cache: c, + watch: watch, + } + collections.byKind[resourceKind] = collections.workloadIdentity case types.KindAutoUpdateConfig: if c.AutoUpdateService == nil { return nil, trace.BadParameter("missing parameter AutoUpdateService") diff --git a/lib/cache/resource_workload_identity.go b/lib/cache/resource_workload_identity.go new file mode 100644 index 0000000000000..75efb50fedbd5 --- /dev/null +++ b/lib/cache/resource_workload_identity.go @@ -0,0 +1,119 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//nolint:unused // Because the executors generate a large amount of false positives. +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +// WorkloadIdentityReader is an interface that defines the methods for getting +// WorkloadIdentity. This is returned as the reader for the WorkloadIdentity +// collection but is also used by the executor to read the full list of +// WorkloadIdentity on initialization. +type WorkloadIdentityReader interface { + ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) + GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) +} + +// workloadIdentityCacher is used for storing and retrieving WorkloadIdentity +// from the cache's local backend. +type workloadIdentityCacher interface { + WorkloadIdentityReader + UpsertWorkloadIdentity(ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity) (*workloadidentityv1pb.WorkloadIdentity, error) + DeleteWorkloadIdentity(ctx context.Context, name string) error + DeleteAllWorkloadIdentities(ctx context.Context) error +} + +type workloadIdentityExecutor struct{} + +var _ executor[*workloadidentityv1pb.WorkloadIdentity, WorkloadIdentityReader] = workloadIdentityExecutor{} + +func (workloadIdentityExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*workloadidentityv1pb.WorkloadIdentity, error) { + var out []*workloadidentityv1pb.WorkloadIdentity + var nextToken string + for { + var page []*workloadidentityv1pb.WorkloadIdentity + var err error + + const defaultPageSize = 0 + page, nextToken, err = cache.Config.WorkloadIdentity.ListWorkloadIdentities(ctx, defaultPageSize, nextToken) + if err != nil { + return nil, trace.Wrap(err) + } + out = append(out, page...) + if nextToken == "" { + break + } + } + return out, nil +} + +func (workloadIdentityExecutor) upsert(ctx context.Context, cache *Cache, resource *workloadidentityv1pb.WorkloadIdentity) error { + _, err := cache.workloadIdentityCache.UpsertWorkloadIdentity(ctx, resource) + return trace.Wrap(err) +} + +func (workloadIdentityExecutor) deleteAll(ctx context.Context, cache *Cache) error { + return trace.Wrap(cache.workloadIdentityCache.DeleteAllWorkloadIdentities(ctx)) +} + +func (workloadIdentityExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { + return trace.Wrap(cache.workloadIdentityCache.DeleteWorkloadIdentity(ctx, resource.GetName())) +} + +func (workloadIdentityExecutor) isSingleton() bool { return false } + +func (workloadIdentityExecutor) getReader(cache *Cache, cacheOK bool) WorkloadIdentityReader { + if cacheOK { + return cache.workloadIdentityCache + } + return cache.Config.WorkloadIdentity +} + +// ListWorkloadIdentities returns a paginated list of WorkloadIdentity resources. +func (c *Cache) ListWorkloadIdentities(ctx context.Context, pageSize int, nextToken string) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListWorkloadIdentities") + defer span.End() + + rg, err := readCollectionCache(c, c.collections.workloadIdentity) + if err != nil { + return nil, "", trace.Wrap(err) + } + defer rg.Release() + out, nextKey, err := rg.reader.ListWorkloadIdentities(ctx, pageSize, nextToken) + return out, nextKey, trace.Wrap(err) +} + +// GetWorkloadIdentity returns a single WorkloadIdentity by name +func (c *Cache) GetWorkloadIdentity(ctx context.Context, name string) (*workloadidentityv1pb.WorkloadIdentity, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetWorkloadIdentity") + defer span.End() + + rg, err := readCollectionCache(c, c.collections.workloadIdentity) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + out, err := rg.reader.GetWorkloadIdentity(ctx, name) + return out, trace.Wrap(err) +} diff --git a/lib/cache/resource_workload_identity_test.go b/lib/cache/resource_workload_identity_test.go new file mode 100644 index 0000000000000..da82d64fec27c --- /dev/null +++ b/lib/cache/resource_workload_identity_test.go @@ -0,0 +1,74 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +func newWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity { + return &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: name, + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + } +} + +func TestWorkloadIdentity(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources153(t, p, testFuncs153[*workloadidentityv1pb.WorkloadIdentity]{ + newResource: func(s string) (*workloadidentityv1pb.WorkloadIdentity, error) { + return newWorkloadIdentity(s), nil + }, + + create: func(ctx context.Context, item *workloadidentityv1pb.WorkloadIdentity) error { + _, err := p.workloadIdentity.CreateWorkloadIdentity(ctx, item) + return trace.Wrap(err) + }, + list: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) { + items, _, err := p.workloadIdentity.ListWorkloadIdentities(ctx, 0, "") + return items, trace.Wrap(err) + }, + deleteAll: func(ctx context.Context) error { + return p.workloadIdentity.DeleteAllWorkloadIdentities(ctx) + }, + + cacheList: func(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentity, error) { + items, _, err := p.cache.ListWorkloadIdentities(ctx, 0, "") + return items, trace.Wrap(err) + }, + cacheGet: p.cache.GetWorkloadIdentity, + }) +} diff --git a/lib/service/service.go b/lib/service/service.go index a60a926a7c486..16276a69827fa 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2548,6 +2548,7 @@ func (process *TeleportProcess) newAccessCacheForServices(cfg accesspoint.Config cfg.WebSession = services.Identity.WebSessions() cfg.WebToken = services.Identity.WebTokens() cfg.WindowsDesktops = services.WindowsDesktops + cfg.WorkloadIdentity = services.WorkloadIdentities cfg.DynamicWindowsDesktops = services.DynamicWindowsDesktops cfg.AutoUpdateService = services.AutoUpdateService cfg.ProvisioningStates = services.ProvisioningStates diff --git a/lib/services/local/events.go b/lib/services/local/events.go index d0522bf7bd5f2..9931b80857500 100644 --- a/lib/services/local/events.go +++ b/lib/services/local/events.go @@ -252,6 +252,8 @@ func (e *EventsService) NewWatcher(ctx context.Context, watch types.Watch) (type parser = newPluginStaticCredentialsParser() case types.KindGitServer: parser = newGitServerParser() + case types.KindWorkloadIdentity: + parser = newWorkloadIdentityParser() default: if watch.AllowPartialSuccess { continue @@ -3179,6 +3181,46 @@ func (p *spiffeFederationParser) parse(event backend.Event) (types.Resource, err } } +func newWorkloadIdentityParser() *workloadIdentityParser { + return &workloadIdentityParser{ + baseParser: newBaseParser(backend.NewKey(workloadIdentityPrefix)), + } +} + +type workloadIdentityParser struct { + baseParser +} + +func (p *workloadIdentityParser) parse(event backend.Event) (types.Resource, error) { + switch event.Type { + case types.OpDelete: + name := event.Item.Key.TrimPrefix(backend.NewKey(workloadIdentityPrefix)).String() + if name == "" { + return nil, trace.NotFound("failed parsing %v", event.Item.Key.String()) + } + + return &types.ResourceHeader{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: types.Metadata{ + Name: strings.TrimPrefix(name, backend.SeparatorString), + Namespace: apidefaults.Namespace, + }, + }, nil + case types.OpPut: + resource, err := services.UnmarshalWorkloadIdentity( + event.Item.Value, + services.WithExpires(event.Item.Expires), + services.WithRevision(event.Item.Revision)) + if err != nil { + return nil, trace.Wrap(err, "unmarshalling resource from event") + } + return types.Resource153ToLegacy(resource), nil + default: + return nil, trace.BadParameter("event %v is not supported", event.Type) + } +} + func newProvisioningStateParser() *provisioningStateParser { return &provisioningStateParser{ baseParser: newBaseParser(backend.NewKey(provisioningStatePrefix)), diff --git a/lib/services/local/workload_identity.go b/lib/services/local/workload_identity.go new file mode 100644 index 0000000000000..e0504e989cbe8 --- /dev/null +++ b/lib/services/local/workload_identity.go @@ -0,0 +1,118 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package local + +import ( + "context" + + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local/generic" +) + +const ( + workloadIdentityPrefix = "workload_identity" +) + +// WorkloadIdentityService exposes backend functionality for storing +// WorkloadIdentity resources +type WorkloadIdentityService struct { + service *generic.ServiceWrapper[*workloadidentityv1pb.WorkloadIdentity] +} + +// NewWorkloadIdentityService creates a new WorkloadIdentityService +func NewWorkloadIdentityService(b backend.Backend) (*WorkloadIdentityService, error) { + service, err := generic.NewServiceWrapper( + generic.ServiceWrapperConfig[*workloadidentityv1pb.WorkloadIdentity]{ + Backend: b, + ResourceKind: types.KindWorkloadIdentity, + BackendPrefix: backend.NewKey(workloadIdentityPrefix), + MarshalFunc: services.MarshalWorkloadIdentity, + UnmarshalFunc: services.UnmarshalWorkloadIdentity, + ValidateFunc: services.ValidateWorkloadIdentity, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return &WorkloadIdentityService{ + service: service, + }, nil +} + +// CreateWorkloadIdentity inserts a new WorkloadIdentity into the backend. +func (b *WorkloadIdentityService) CreateWorkloadIdentity( + ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + created, err := b.service.CreateResource(ctx, resource) + return created, trace.Wrap(err) +} + +// GetWorkloadIdentity retrieves a specific WorkloadIdentity given a name +func (b *WorkloadIdentityService) GetWorkloadIdentity( + ctx context.Context, name string, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + resource, err := b.service.GetResource(ctx, name) + return resource, trace.Wrap(err) +} + +// ListWorkloadIdentities lists all WorkloadIdentities using a given page size +// and last key. +func (b *WorkloadIdentityService) ListWorkloadIdentities( + ctx context.Context, pageSize int, currentToken string, +) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) { + r, nextToken, err := b.service.ListResources(ctx, pageSize, currentToken) + return r, nextToken, trace.Wrap(err) +} + +// DeleteWorkloadIdentity deletes a specific WorkloadIdentity. +func (b *WorkloadIdentityService) DeleteWorkloadIdentity( + ctx context.Context, name string, +) error { + return trace.Wrap(b.service.DeleteResource(ctx, name)) +} + +// DeleteAllWorkloadIdentities deletes all SPIFFE resources, this is typically +// only meant to be used by the cache. +func (b *WorkloadIdentityService) DeleteAllWorkloadIdentities( + ctx context.Context, +) error { + return trace.Wrap(b.service.DeleteAllResources(ctx)) +} + +// UpsertWorkloadIdentity upserts a WorkloadIdentitys. Prefer using +// CreateWorkloadIdentity. This is only designed for usage by the cache. +func (b *WorkloadIdentityService) UpsertWorkloadIdentity( + ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + upserted, err := b.service.UpsertResource(ctx, resource) + return upserted, trace.Wrap(err) +} + +// UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must +// already exist, and, condition update semantics are used - e.g the submitted +// resource must have a revision matching the revision of the resource in the +// backend. +func (b *WorkloadIdentityService) UpdateWorkloadIdentity( + ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentity, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + updated, err := b.service.ConditionalUpdateResource(ctx, resource) + return updated, trace.Wrap(err) +} diff --git a/lib/services/local/workload_identity_test.go b/lib/services/local/workload_identity_test.go new file mode 100644 index 0000000000000..acba05d9c8e4a --- /dev/null +++ b/lib/services/local/workload_identity_test.go @@ -0,0 +1,355 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package local + +import ( + "context" + "fmt" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/memory" +) + +func setupWorkloadIdentityServiceTest( + t *testing.T, +) (context.Context, *WorkloadIdentityService) { + t.Parallel() + ctx := context.Background() + clock := clockwork.NewFakeClock() + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + service, err := NewWorkloadIdentityService(backend.NewSanitizer(mem)) + require.NoError(t, err) + return ctx, service +} + +func newValidWorkloadIdentity(name string) *workloadidentityv1pb.WorkloadIdentity { + return &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: name, + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/test", + }, + }, + } +} + +func TestWorkloadIdentityService_CreateWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + want := newValidWorkloadIdentity("example") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + require.Empty(t, cmp.Diff( + want, + got, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + }) + t.Run("validation occurs", func(t *testing.T) { + out, err := service.CreateWorkloadIdentity(ctx, newValidWorkloadIdentity("")) + require.ErrorContains(t, err, "metadata.name: is required") + require.Nil(t, out) + }) + t.Run("no upsert", func(t *testing.T) { + res := newValidWorkloadIdentity("duplicate") + _, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + _, err = service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(res).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.Error(t, err) + require.True(t, trace.IsAlreadyExists(err)) + }) +} + +func TestWorkloadIdentityService_UpsertWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + want := newValidWorkloadIdentity("example") + got, err := service.UpsertWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + require.Empty(t, cmp.Diff( + want, + got, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + + // Ensure we can upsert over an existing resource + _, err = service.UpsertWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + }) + t.Run("validation occurs", func(t *testing.T) { + out, err := service.UpdateWorkloadIdentity(ctx, newValidWorkloadIdentity("")) + require.ErrorContains(t, err, "metadata.name: is required") + require.Nil(t, out) + }) +} + +func TestWorkloadIdentityService_ListWorkloadIdentities(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + // Create entities to list + createdObjects := []*workloadidentityv1pb.WorkloadIdentity{} + // Create 49 entities to test an incomplete page at the end. + for i := 0; i < 49; i++ { + created, err := service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity(fmt.Sprintf("%d", i)), + ) + require.NoError(t, err) + createdObjects = append(createdObjects, created) + } + t.Run("default page size", func(t *testing.T) { + page, nextToken, err := service.ListWorkloadIdentities(ctx, 0, "") + require.NoError(t, err) + require.Len(t, page, 49) + require.Empty(t, nextToken) + + // Expect that we get all the things we have created + for _, created := range createdObjects { + slices.ContainsFunc(page, func(resource *workloadidentityv1pb.WorkloadIdentity) bool { + return proto.Equal(created, resource) + }) + } + }) + t.Run("pagination", func(t *testing.T) { + fetched := []*workloadidentityv1pb.WorkloadIdentity{} + token := "" + iterations := 0 + for { + iterations++ + page, nextToken, err := service.ListWorkloadIdentities(ctx, 10, token) + require.NoError(t, err) + fetched = append(fetched, page...) + if nextToken == "" { + break + } + token = nextToken + } + require.Equal(t, 5, iterations) + + require.Len(t, fetched, 49) + // Expect that we get all the things we have created + for _, created := range createdObjects { + slices.ContainsFunc(fetched, func(resource *workloadidentityv1pb.WorkloadIdentity) bool { + return proto.Equal(created, resource) + }) + } + }) +} + +func TestWorkloadIdentityService_GetWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + want := newValidWorkloadIdentity("example") + _, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(want).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + got, err := service.GetWorkloadIdentity(ctx, "example") + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + require.Empty(t, cmp.Diff( + want, + got, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + }) + t.Run("not found", func(t *testing.T) { + _, err := service.GetWorkloadIdentity(ctx, "not-found") + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + }) +} + +func TestWorkloadIdentityService_DeleteWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + _, err := service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity("example"), + ) + require.NoError(t, err) + + _, err = service.GetWorkloadIdentity(ctx, "example") + require.NoError(t, err) + + err = service.DeleteWorkloadIdentity(ctx, "example") + require.NoError(t, err) + + _, err = service.GetWorkloadIdentity(ctx, "example") + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + }) + t.Run("not found", func(t *testing.T) { + err := service.DeleteWorkloadIdentity(ctx, "foo.example.com") + require.Error(t, err) + require.True(t, trace.IsNotFound(err)) + }) +} + +func TestWorkloadIdentityService_DeleteAllWorkloadIdentities(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + _, err := service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity("1"), + ) + require.NoError(t, err) + _, err = service.CreateWorkloadIdentity( + ctx, + newValidWorkloadIdentity("2"), + ) + require.NoError(t, err) + + page, _, err := service.ListWorkloadIdentities(ctx, 0, "") + require.NoError(t, err) + require.Len(t, page, 2) + + err = service.DeleteAllWorkloadIdentities(ctx) + require.NoError(t, err) + + page, _, err = service.ListWorkloadIdentities(ctx, 0, "") + require.NoError(t, err) + require.Empty(t, page) +} + +func TestWorkloadIdentityService_UpdateWorkloadIdentity(t *testing.T) { + ctx, service := setupWorkloadIdentityServiceTest(t) + + t.Run("ok", func(t *testing.T) { + // Create first to support updating + toCreate := newValidWorkloadIdentity("example") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + got.Spec.Spiffe.Id = "/changed" + got2, err := service.UpdateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got2.Metadata.Revision) + require.Empty(t, cmp.Diff( + got, + got2, + protocmp.Transform(), + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + )) + }) + t.Run("validation occurs", func(t *testing.T) { + // Create first to support updating + toCreate := newValidWorkloadIdentity("example2") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + require.NotEmpty(t, got.Metadata.Revision) + got.Spec.Spiffe.Id = "" + got2, err := service.UpdateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.ErrorContains(t, err, "spec.spiffe.id: is required") + require.Nil(t, got2) + }) + t.Run("cond update blocks", func(t *testing.T) { + toCreate := newValidWorkloadIdentity("example4") + got, err := service.CreateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toCreate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + // We'll now update it twice, but on the second update, we will use the + // revision from the creation not the second update. + _, err = service.UpdateWorkloadIdentity( + ctx, + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.NoError(t, err) + _, err = service.UpdateWorkloadIdentity( + ctx, + proto.Clone(got).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.ErrorIs(t, err, backend.ErrIncorrectRevision) + }) + t.Run("no upsert", func(t *testing.T) { + toUpdate := newValidWorkloadIdentity("example3") + _, err := service.UpdateWorkloadIdentity( + ctx, + // Clone to avoid Marshaling modifying want + proto.Clone(toUpdate).(*workloadidentityv1pb.WorkloadIdentity), + ) + require.Error(t, err) + }) +} diff --git a/lib/services/workload_identity.go b/lib/services/workload_identity.go new file mode 100644 index 0000000000000..89b87ba0d2473 --- /dev/null +++ b/lib/services/workload_identity.go @@ -0,0 +1,122 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package services + +import ( + "context" + "strings" + + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +// WorkloadIdentities is an interface over the WorkloadIdentities service. This +// interface may also be implemented by a client to allow remote and local +// consumers to access the resource in a similar way. +type WorkloadIdentities interface { + // GetWorkloadIdentity gets a SPIFFE Federation by name. + GetWorkloadIdentity( + ctx context.Context, name string, + ) (*workloadidentityv1pb.WorkloadIdentity, error) + // ListWorkloadIdentities lists all WorkloadIdentities using Google style + // pagination. + ListWorkloadIdentities( + ctx context.Context, pageSize int, lastToken string, + ) ([]*workloadidentityv1pb.WorkloadIdentity, string, error) + // CreateWorkloadIdentity creates a new WorkloadIdentity. + CreateWorkloadIdentity( + ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity, + ) (*workloadidentityv1pb.WorkloadIdentity, error) + // DeleteWorkloadIdentity deletes a SPIFFE Federation by name. + DeleteWorkloadIdentity(ctx context.Context, name string) error + // UpdateWorkloadIdentity updates a specific WorkloadIdentity. The resource must + // already exist, and, condition update semantics are used - e.g the submitted + // resource must have a revision matching the revision of the resource in the + // backend. + UpdateWorkloadIdentity( + ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity, + ) (*workloadidentityv1pb.WorkloadIdentity, error) + // UpsertWorkloadIdentity creates or updates a WorkloadIdentity. + UpsertWorkloadIdentity( + ctx context.Context, workloadIdentity *workloadidentityv1pb.WorkloadIdentity, + ) (*workloadidentityv1pb.WorkloadIdentity, error) +} + +// MarshalWorkloadIdentity marshals the WorkloadIdentity object into a JSON byte +// array. +func MarshalWorkloadIdentity( + object *workloadidentityv1pb.WorkloadIdentity, opts ...MarshalOption, +) ([]byte, error) { + return MarshalProtoResource(object, opts...) +} + +// UnmarshalWorkloadIdentity unmarshals the WorkloadIdentity object from a +// JSON byte array. +func UnmarshalWorkloadIdentity( + data []byte, opts ...MarshalOption, +) (*workloadidentityv1pb.WorkloadIdentity, error) { + return UnmarshalProtoResource[*workloadidentityv1pb.WorkloadIdentity](data, opts...) +} + +// ValidateWorkloadIdentity validates the WorkloadIdentity object. This is +// performed prior to writing to the backend. +func ValidateWorkloadIdentity(s *workloadidentityv1pb.WorkloadIdentity) error { + switch { + case s == nil: + return trace.BadParameter("object cannot be nil") + case s.Version != types.V1: + return trace.BadParameter("version: only %q is supported", types.V1) + case s.Kind != types.KindWorkloadIdentity: + return trace.BadParameter("kind: must be %q", types.KindWorkloadIdentity) + case s.Metadata == nil: + return trace.BadParameter("metadata: is required") + case s.Metadata.Name == "": + return trace.BadParameter("metadata.name: is required") + case s.Spec == nil: + return trace.BadParameter("spec: is required") + case s.Spec.Spiffe.Id == "": + return trace.BadParameter("spec.spiffe.id: is required") + case !strings.HasPrefix(s.Spec.Spiffe.Id, "/"): + return trace.BadParameter("spec.spiffe.id: must start with a /") + } + + for i, rule := range s.GetSpec().GetRules().GetAllow() { + if len(rule.Conditions) == 0 { + return trace.BadParameter("spec.rules.allow[%d].conditions: must be non-empty", i) + } + for j, condition := range rule.Conditions { + if condition.Attribute == "" { + return trace.BadParameter("spec.rules.allow[%d].conditions[%d].attribute: must be non-empty", i, j) + } + // Ensure exactly one operator is set. + operatorsSet := 0 + if condition.Equals != "" { + operatorsSet++ + } + if operatorsSet == 0 || operatorsSet > 1 { + return trace.BadParameter( + "spec.rules.allow[%d].conditions[%d]: exactly one operator must be specified, found %d", + i, j, operatorsSet, + ) + } + } + } + + return nil +} diff --git a/lib/services/workload_identity_test.go b/lib/services/workload_identity_test.go new file mode 100644 index 0000000000000..429612ed48555 --- /dev/null +++ b/lib/services/workload_identity_test.go @@ -0,0 +1,231 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package services + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" +) + +func TestWorkloadIdentityMarshaling(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + in *workloadidentityv1pb.WorkloadIdentity + }{ + { + name: "normal", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotBytes, err := MarshalWorkloadIdentity(tc.in) + require.NoError(t, err) + // Test that unmarshaling gives us the same object + got, err := UnmarshalWorkloadIdentity(gotBytes) + require.NoError(t, err) + require.Empty(t, cmp.Diff(tc.in, got, protocmp.Transform())) + }) + } +} + +func TestValidateWorkloadIdentity(t *testing.T) { + t.Parallel() + + var errContains = func(contains string) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, msgAndArgs ...interface{}) { + require.ErrorContains(t, err, contains, msgAndArgs...) + } + } + + testCases := []struct { + name string + in *workloadidentityv1pb.WorkloadIdentity + requireErr require.ErrorAssertionFunc + }{ + { + name: "success - full", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "example", + Equals: "foo", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: require.NoError, + }, + { + name: "success - minimal", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: require.NoError, + }, + { + name: "missing name", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{}, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: errContains("metadata.name: is required"), + }, + { + name: "missing spiffe id", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{}, + }, + }, + requireErr: errContains("spec.spiffe.id: is required"), + }, + { + name: "spiffe id must have leading /", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "example", + }, + }, + }, + requireErr: errContains("spec.spiffe.id: must start with a /"), + }, + { + name: "missing attribute", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "", + Equals: "foo", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: errContains("spec.rules.allow[0].conditions[0].attribute: must be non-empty"), + }, + { + name: "missing operator", + in: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "example", + }, + }, + }, + }, + }, + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/example", + }, + }, + }, + requireErr: errContains("spec.rules.allow[0].conditions[0]: exactly one operator must be specified, found 0"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateWorkloadIdentity(tc.in) + tc.requireErr(t, err) + }) + } +}