Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor lib/tbot/spiffe to use WorkloadAttrs proto #50833

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 47 additions & 32 deletions lib/tbot/service_spiffe_workload_api.go
Original file line number Diff line number Diff line change
@@ -52,6 +52,7 @@ import (

"github.com/gravitational/teleport"
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/observability/metrics"
@@ -227,13 +228,27 @@ func (s *SPIFFEWorkloadAPIService) Run(ctx context.Context) error {
)
workloadpb.RegisterSpiffeWorkloadAPIServer(srv, s)
sdsHandler := &spiffeSDSHandler{
log: s.log,
cfg: s.cfg,
botCfg: s.botCfg,

trustBundleCache: s.trustBundleCache,
clientAuthenticator: s.authenticateClient,
svidFetcher: s.fetchX509SVIDs,
log: s.log,
botCfg: s.botCfg,
trustBundleCache: s.trustBundleCache,
clientAuthenticator: func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
log, attrs, err := s.authenticateClient(ctx)
if err != nil {
return log, nil, trace.Wrap(err, "authenticating client")
}
fetchSVIDs := func(
ctx context.Context,
localBundle *spiffebundle.Bundle,
) ([]*workloadpb.X509SVID, error) {
return s.fetchX509SVIDs(
ctx,
log,
localBundle,
filterSVIDRequests(ctx, log, s.cfg.SVIDs, attrs),
)
}
return log, fetchSVIDs, nil
},
}
secretv3pb.RegisterSecretDiscoveryServiceServer(srv, sdsHandler)

@@ -373,7 +388,7 @@ func filterSVIDRequests(
ctx context.Context,
log *slog.Logger,
svidRequests []config.SVIDRequestWithRules,
att workloadattest.Attestation,
att *workloadidentityv1pb.WorkloadAttrs,
) []config.SVIDRequest {
var filtered []config.SVIDRequest
for _, req := range svidRequests {
@@ -413,67 +428,67 @@ func filterSVIDRequests(
"Evaluating rule against workload attestation",
)
if rule.Unix.UID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.UID != att.Unix.UID {
logMismatch("unix.uid", *rule.Unix.UID, att.Unix.UID)
if *rule.Unix.UID != int(att.GetUnix().GetUid()) {
logMismatch("unix.uid", *rule.Unix.UID, att.GetUnix().GetUid())
continue
}
// Rule field matched!
}
if rule.Unix.PID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.PID != att.Unix.PID {
logMismatch("unix.pid", *rule.Unix.PID, att.Unix.PID)
if *rule.Unix.PID != int(att.GetUnix().GetPid()) {
logMismatch("unix.pid", *rule.Unix.PID, att.GetUnix().GetPid())
continue
}
// Rule field matched!
}
if rule.Unix.GID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.GID != att.Unix.GID {
logMismatch("unix.gid", *rule.Unix.GID, att.Unix.GID)
if *rule.Unix.GID != int(att.GetUnix().GetGid()) {
logMismatch("unix.gid", *rule.Unix.GID, att.GetUnix().GetGid())
continue
}
// Rule field matched!
}
if rule.Kubernetes.Namespace != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.Namespace != att.Kubernetes.Namespace {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.Kubernetes.Namespace)
if rule.Kubernetes.Namespace != att.GetKubernetes().GetNamespace() {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.GetKubernetes().GetNamespace())
continue
}
// Rule field matched!
}
if rule.Kubernetes.PodName != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.PodName != att.Kubernetes.PodName {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.Kubernetes.PodName)
if rule.Kubernetes.PodName != att.GetKubernetes().GetPodName() {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.GetKubernetes().GetPodName())
continue
}
// Rule field matched!
}
if rule.Kubernetes.ServiceAccount != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.ServiceAccount != att.Kubernetes.ServiceAccount {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.Kubernetes.ServiceAccount)
if rule.Kubernetes.ServiceAccount != att.GetKubernetes().GetServiceAccount() {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.GetKubernetes().GetServiceAccount())
continue
}
// Rule field matched!
@@ -499,10 +514,10 @@ func filterSVIDRequests(

func (s *SPIFFEWorkloadAPIService) authenticateClient(
ctx context.Context,
) (*slog.Logger, workloadattest.Attestation, error) {
) (*slog.Logger, *workloadidentityv1pb.WorkloadAttrs, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, workloadattest.Attestation{}, trace.BadParameter("peer not found in context")
return nil, nil, trace.BadParameter("peer not found in context")
}
log := s.log

@@ -516,7 +531,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
// We expect Creds to be nil/unset if the client is connecting via TCP and
// therefore there is no workload attestation that can be completed.
if !ok || authInfo.Creds == nil {
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

// For a UDS, sometimes we are unable to determine the PID of the calling
@@ -528,7 +543,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
if authInfo.Creds.PID == 0 {
log.DebugContext(
ctx, "Failed to determine the PID of the calling workload. TBot may be running in a different process namespace to the workload. Workload attestation will not be completed.")
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

att, err := s.attestor.Attest(ctx, authInfo.Creds.PID)
@@ -541,10 +556,10 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
"error", err,
"pid", authInfo.Creds.PID,
)
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}
log = log.With(
"workload", slog.LogValuer(att),
"workload", att,
)

return log, att, nil
27 changes: 7 additions & 20 deletions lib/tbot/service_spiffe_workload_api_sds.go
Original file line number Diff line number Diff line change
@@ -40,7 +40,6 @@ import (

"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
)

@@ -63,23 +62,18 @@ type bundleSetGetter interface {
GetBundleSet(ctx context.Context) (*workloadidentity.BundleSet, error)
}

type svidFetcher func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error)

// spiffeSDSHandler implements an Envoy SDS API.
//
// This effectively replaces the Workload API for Envoy, but functions in a
// very similar way.
type spiffeSDSHandler struct {
log *slog.Logger
cfg *config.SPIFFEWorkloadAPIService
botCfg *config.BotConfig
trustBundleCache bundleSetGetter

clientAuthenticator func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error)
svidFetcher func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest,
) ([]*workloadpb.X509SVID, error)
clientAuthenticator func(ctx context.Context) (*slog.Logger, svidFetcher, error)
}

// FetchSecrets implements
@@ -97,7 +91,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err)
}

log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return nil, trace.Wrap(err, "authenticating client")
}
@@ -114,11 +108,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err, "getting trust bundle set")
}

// Filter SVIDs down to those accessible to this workload
svids, err := s.svidFetcher(
ctx,
log,
bundleSet.Local, filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds))
svids, err := fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return nil, trace.Wrap(err, "fetching X509 SVIDs")
}
@@ -174,7 +164,7 @@ func (s *spiffeSDSHandler) StreamSecrets(
srv secretv3pb.SecretDiscoveryService_StreamSecretsServer,
) error {
ctx := srv.Context()
log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return trace.Wrap(err, "authenticating client")
}
@@ -216,9 +206,6 @@ func (s *spiffeSDSHandler) StreamSecrets(
renewalTimer.Stop()
defer renewalTimer.Stop()

// Filter SVIDs down to those accessible to this workload
availableSVIDs := filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds)

// Track the last response and last request to allow us to handle ACK/NACK
// and versioning.
var (
@@ -311,7 +298,7 @@ func (s *spiffeSDSHandler) StreamSecrets(

// Fetch the SVIDs if necessary
if svids == nil {
svids, err = s.svidFetcher(ctx, log, bundleSet.Local, availableSVIDs)
svids, err = fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return trace.Wrap(err, "fetching X509 SVIDs")
}
91 changes: 16 additions & 75 deletions lib/tbot/service_spiffe_workload_api_sds_test.go
Original file line number Diff line number Diff line change
@@ -35,7 +35,6 @@ import (
discoveryv3pb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
secretv3pb "github.com/envoyproxy/go-control-plane/envoy/service/secret/v3"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/spiffe/go-spiffe/v2/bundle/spiffebundle"
workloadpb "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
"github.com/spiffe/go-spiffe/v2/spiffeid"
@@ -51,7 +50,6 @@ import (
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/testutils/golden"
"github.com/gravitational/teleport/tool/teleport/testenv"
@@ -80,14 +78,22 @@ func TestSDS_FetchSecrets(t *testing.T) {
ca, err := x509.ParseCertificate(b.Bytes)
require.NoError(t, err)

uid := 100
notUID := 200
clientAuthenticator := func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) {
return log, workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
Attested: true,
UID: uid,
},
clientAuthenticator := func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
return log, func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) {
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}, nil
}

@@ -105,72 +111,9 @@ func TestSDS_FetchSecrets(t *testing.T) {
},
},
}
svidFetcher := func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest) ([]*workloadpb.X509SVID, error) {
if len(svidRequests) != 2 {
return nil, trace.BadParameter("expected 2 svids requested")
}
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}
botConfig := &config.BotConfig{
RenewalInterval: time.Minute,
}
cfg := &config.SPIFFEWorkloadAPIService{
SVIDs: []config.SVIDRequestWithRules{
{
SVIDRequest: config.SVIDRequest{
Path: "/default",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/second",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/not-matching",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &notUID,
},
},
},
},
},
}

tests := []struct {
name string
@@ -231,12 +174,10 @@ func TestSDS_FetchSecrets(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
sds := &spiffeSDSHandler{
log: log,
cfg: cfg,
botCfg: botConfig,

trustBundleCache: mockBundleCache,
clientAuthenticator: clientAuthenticator,
svidFetcher: svidFetcher,
}

req := &discoveryv3pb.DiscoveryRequest{
Loading