Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] Refactor lib/tbot/spiffe to use WorkloadAttrs proto (#50833) #51061

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 47 additions & 32 deletions lib/tbot/service_spiffe_workload_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (

"github.com/gravitational/teleport"
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/observability/metrics"
Expand Down Expand Up @@ -227,13 +228,27 @@ func (s *SPIFFEWorkloadAPIService) Run(ctx context.Context) error {
)
workloadpb.RegisterSpiffeWorkloadAPIServer(srv, s)
sdsHandler := &spiffeSDSHandler{
log: s.log,
cfg: s.cfg,
botCfg: s.botCfg,

trustBundleCache: s.trustBundleCache,
clientAuthenticator: s.authenticateClient,
svidFetcher: s.fetchX509SVIDs,
log: s.log,
botCfg: s.botCfg,
trustBundleCache: s.trustBundleCache,
clientAuthenticator: func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
log, attrs, err := s.authenticateClient(ctx)
if err != nil {
return log, nil, trace.Wrap(err, "authenticating client")
}
fetchSVIDs := func(
ctx context.Context,
localBundle *spiffebundle.Bundle,
) ([]*workloadpb.X509SVID, error) {
return s.fetchX509SVIDs(
ctx,
log,
localBundle,
filterSVIDRequests(ctx, log, s.cfg.SVIDs, attrs),
)
}
return log, fetchSVIDs, nil
},
}
secretv3pb.RegisterSecretDiscoveryServiceServer(srv, sdsHandler)

Expand Down Expand Up @@ -373,7 +388,7 @@ func filterSVIDRequests(
ctx context.Context,
log *slog.Logger,
svidRequests []config.SVIDRequestWithRules,
att workloadattest.Attestation,
att *workloadidentityv1pb.WorkloadAttrs,
) []config.SVIDRequest {
var filtered []config.SVIDRequest
for _, req := range svidRequests {
Expand Down Expand Up @@ -413,67 +428,67 @@ func filterSVIDRequests(
"Evaluating rule against workload attestation",
)
if rule.Unix.UID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.UID != att.Unix.UID {
logMismatch("unix.uid", *rule.Unix.UID, att.Unix.UID)
if *rule.Unix.UID != int(att.GetUnix().GetUid()) {
logMismatch("unix.uid", *rule.Unix.UID, att.GetUnix().GetUid())
continue
}
// Rule field matched!
}
if rule.Unix.PID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.PID != att.Unix.PID {
logMismatch("unix.pid", *rule.Unix.PID, att.Unix.PID)
if *rule.Unix.PID != int(att.GetUnix().GetPid()) {
logMismatch("unix.pid", *rule.Unix.PID, att.GetUnix().GetPid())
continue
}
// Rule field matched!
}
if rule.Unix.GID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.GID != att.Unix.GID {
logMismatch("unix.gid", *rule.Unix.GID, att.Unix.GID)
if *rule.Unix.GID != int(att.GetUnix().GetGid()) {
logMismatch("unix.gid", *rule.Unix.GID, att.GetUnix().GetGid())
continue
}
// Rule field matched!
}
if rule.Kubernetes.Namespace != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.Namespace != att.Kubernetes.Namespace {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.Kubernetes.Namespace)
if rule.Kubernetes.Namespace != att.GetKubernetes().GetNamespace() {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.GetKubernetes().GetNamespace())
continue
}
// Rule field matched!
}
if rule.Kubernetes.PodName != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.PodName != att.Kubernetes.PodName {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.Kubernetes.PodName)
if rule.Kubernetes.PodName != att.GetKubernetes().GetPodName() {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.GetKubernetes().GetPodName())
continue
}
// Rule field matched!
}
if rule.Kubernetes.ServiceAccount != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.ServiceAccount != att.Kubernetes.ServiceAccount {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.Kubernetes.ServiceAccount)
if rule.Kubernetes.ServiceAccount != att.GetKubernetes().GetServiceAccount() {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.GetKubernetes().GetServiceAccount())
continue
}
// Rule field matched!
Expand All @@ -499,10 +514,10 @@ func filterSVIDRequests(

func (s *SPIFFEWorkloadAPIService) authenticateClient(
ctx context.Context,
) (*slog.Logger, workloadattest.Attestation, error) {
) (*slog.Logger, *workloadidentityv1pb.WorkloadAttrs, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, workloadattest.Attestation{}, trace.BadParameter("peer not found in context")
return nil, nil, trace.BadParameter("peer not found in context")
}
log := s.log

Expand All @@ -516,7 +531,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
// We expect Creds to be nil/unset if the client is connecting via TCP and
// therefore there is no workload attestation that can be completed.
if !ok || authInfo.Creds == nil {
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

// For a UDS, sometimes we are unable to determine the PID of the calling
Expand All @@ -528,7 +543,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
if authInfo.Creds.PID == 0 {
log.DebugContext(
ctx, "Failed to determine the PID of the calling workload. TBot may be running in a different process namespace to the workload. Workload attestation will not be completed.")
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

att, err := s.attestor.Attest(ctx, authInfo.Creds.PID)
Expand All @@ -541,10 +556,10 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
"error", err,
"pid", authInfo.Creds.PID,
)
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}
log = log.With(
"workload", slog.LogValuer(att),
"workload", att,
)

return log, att, nil
Expand Down
27 changes: 7 additions & 20 deletions lib/tbot/service_spiffe_workload_api_sds.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (

"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
)

Expand All @@ -63,23 +62,18 @@ type bundleSetGetter interface {
GetBundleSet(ctx context.Context) (*workloadidentity.BundleSet, error)
}

type svidFetcher func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error)

// spiffeSDSHandler implements an Envoy SDS API.
//
// This effectively replaces the Workload API for Envoy, but functions in a
// very similar way.
type spiffeSDSHandler struct {
log *slog.Logger
cfg *config.SPIFFEWorkloadAPIService
botCfg *config.BotConfig
trustBundleCache bundleSetGetter

clientAuthenticator func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error)
svidFetcher func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest,
) ([]*workloadpb.X509SVID, error)
clientAuthenticator func(ctx context.Context) (*slog.Logger, svidFetcher, error)
}

// FetchSecrets implements
Expand All @@ -97,7 +91,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err)
}

log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return nil, trace.Wrap(err, "authenticating client")
}
Expand All @@ -114,11 +108,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err, "getting trust bundle set")
}

// Filter SVIDs down to those accessible to this workload
svids, err := s.svidFetcher(
ctx,
log,
bundleSet.Local, filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds))
svids, err := fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return nil, trace.Wrap(err, "fetching X509 SVIDs")
}
Expand Down Expand Up @@ -174,7 +164,7 @@ func (s *spiffeSDSHandler) StreamSecrets(
srv secretv3pb.SecretDiscoveryService_StreamSecretsServer,
) error {
ctx := srv.Context()
log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return trace.Wrap(err, "authenticating client")
}
Expand Down Expand Up @@ -216,9 +206,6 @@ func (s *spiffeSDSHandler) StreamSecrets(
renewalTimer.Stop()
defer renewalTimer.Stop()

// Filter SVIDs down to those accessible to this workload
availableSVIDs := filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds)

// Track the last response and last request to allow us to handle ACK/NACK
// and versioning.
var (
Expand Down Expand Up @@ -311,7 +298,7 @@ func (s *spiffeSDSHandler) StreamSecrets(

// Fetch the SVIDs if necessary
if svids == nil {
svids, err = s.svidFetcher(ctx, log, bundleSet.Local, availableSVIDs)
svids, err = fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return trace.Wrap(err, "fetching X509 SVIDs")
}
Expand Down
91 changes: 16 additions & 75 deletions lib/tbot/service_spiffe_workload_api_sds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
discoveryv3pb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
secretv3pb "github.com/envoyproxy/go-control-plane/envoy/service/secret/v3"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/spiffe/go-spiffe/v2/bundle/spiffebundle"
workloadpb "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
"github.com/spiffe/go-spiffe/v2/spiffeid"
Expand All @@ -51,7 +50,6 @@ import (
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/testutils/golden"
"github.com/gravitational/teleport/tool/teleport/testenv"
Expand Down Expand Up @@ -80,14 +78,22 @@ func TestSDS_FetchSecrets(t *testing.T) {
ca, err := x509.ParseCertificate(b.Bytes)
require.NoError(t, err)

uid := 100
notUID := 200
clientAuthenticator := func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) {
return log, workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
Attested: true,
UID: uid,
},
clientAuthenticator := func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
return log, func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) {
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}, nil
}

Expand All @@ -105,72 +111,9 @@ func TestSDS_FetchSecrets(t *testing.T) {
},
},
}
svidFetcher := func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest) ([]*workloadpb.X509SVID, error) {
if len(svidRequests) != 2 {
return nil, trace.BadParameter("expected 2 svids requested")
}
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}
botConfig := &config.BotConfig{
RenewalInterval: time.Minute,
}
cfg := &config.SPIFFEWorkloadAPIService{
SVIDs: []config.SVIDRequestWithRules{
{
SVIDRequest: config.SVIDRequest{
Path: "/default",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/second",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/not-matching",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &notUID,
},
},
},
},
},
}

tests := []struct {
name string
Expand Down Expand Up @@ -231,12 +174,10 @@ func TestSDS_FetchSecrets(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
sds := &spiffeSDSHandler{
log: log,
cfg: cfg,
botCfg: botConfig,

trustBundleCache: mockBundleCache,
clientAuthenticator: clientAuthenticator,
svidFetcher: svidFetcher,
}

req := &discoveryv3pb.DiscoveryRequest{
Expand Down
Loading
Loading