Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor lib/tbot/spiffe to use WorkloadAttrs proto #50833

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 47 additions & 32 deletions lib/tbot/service_spiffe_workload_api.go
Original file line number Diff line number Diff line change
@@ -52,6 +52,7 @@ import (

"github.com/gravitational/teleport"
machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/observability/metrics"
@@ -227,13 +228,27 @@ func (s *SPIFFEWorkloadAPIService) Run(ctx context.Context) error {
)
workloadpb.RegisterSpiffeWorkloadAPIServer(srv, s)
sdsHandler := &spiffeSDSHandler{
log: s.log,
cfg: s.cfg,
botCfg: s.botCfg,

trustBundleCache: s.trustBundleCache,
clientAuthenticator: s.authenticateClient,
svidFetcher: s.fetchX509SVIDs,
log: s.log,
botCfg: s.botCfg,
trustBundleCache: s.trustBundleCache,
clientAuthenticator: func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
log, attrs, err := s.authenticateClient(ctx)
if err != nil {
return log, nil, trace.Wrap(err, "authenticating client")
}
fetchSVIDs := func(
ctx context.Context,
localBundle *spiffebundle.Bundle,
) ([]*workloadpb.X509SVID, error) {
return s.fetchX509SVIDs(
ctx,
log,
localBundle,
filterSVIDRequests(ctx, log, s.cfg.SVIDs, attrs),
)
}
return log, fetchSVIDs, nil
},
}
secretv3pb.RegisterSecretDiscoveryServiceServer(srv, sdsHandler)

@@ -373,7 +388,7 @@ func filterSVIDRequests(
ctx context.Context,
log *slog.Logger,
svidRequests []config.SVIDRequestWithRules,
att workloadattest.Attestation,
att *workloadidentityv1pb.WorkloadAttrs,
) []config.SVIDRequest {
var filtered []config.SVIDRequest
for _, req := range svidRequests {
@@ -413,67 +428,67 @@ func filterSVIDRequests(
"Evaluating rule against workload attestation",
)
if rule.Unix.UID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.UID != att.Unix.UID {
logMismatch("unix.uid", *rule.Unix.UID, att.Unix.UID)
if *rule.Unix.UID != int(att.GetUnix().GetUid()) {
logMismatch("unix.uid", *rule.Unix.UID, att.GetUnix().GetUid())
continue
}
// Rule field matched!
}
if rule.Unix.PID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.PID != att.Unix.PID {
logMismatch("unix.pid", *rule.Unix.PID, att.Unix.PID)
if *rule.Unix.PID != int(att.GetUnix().GetPid()) {
logMismatch("unix.pid", *rule.Unix.PID, att.GetUnix().GetPid())
continue
}
// Rule field matched!
}
if rule.Unix.GID != nil {
if !att.Unix.Attested {
if !att.GetUnix().GetAttested() {
logNotAttested("unix")
continue
}
if *rule.Unix.GID != att.Unix.GID {
logMismatch("unix.gid", *rule.Unix.GID, att.Unix.GID)
if *rule.Unix.GID != int(att.GetUnix().GetGid()) {
logMismatch("unix.gid", *rule.Unix.GID, att.GetUnix().GetGid())
continue
}
// Rule field matched!
}
if rule.Kubernetes.Namespace != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.Namespace != att.Kubernetes.Namespace {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.Kubernetes.Namespace)
if rule.Kubernetes.Namespace != att.GetKubernetes().GetNamespace() {
logMismatch("kubernetes.namespace", rule.Kubernetes.Namespace, att.GetKubernetes().GetNamespace())
continue
}
// Rule field matched!
}
if rule.Kubernetes.PodName != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.PodName != att.Kubernetes.PodName {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.Kubernetes.PodName)
if rule.Kubernetes.PodName != att.GetKubernetes().GetPodName() {
logMismatch("kubernetes.pod_name", rule.Kubernetes.PodName, att.GetKubernetes().GetPodName())
continue
}
// Rule field matched!
}
if rule.Kubernetes.ServiceAccount != "" {
if !att.Kubernetes.Attested {
if !att.GetKubernetes().GetAttested() {
logNotAttested("kubernetes")
continue
}
if rule.Kubernetes.ServiceAccount != att.Kubernetes.ServiceAccount {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.Kubernetes.ServiceAccount)
if rule.Kubernetes.ServiceAccount != att.GetKubernetes().GetServiceAccount() {
logMismatch("kubernetes.service_account", rule.Kubernetes.ServiceAccount, att.GetKubernetes().GetServiceAccount())
continue
}
// Rule field matched!
@@ -499,10 +514,10 @@ func filterSVIDRequests(

func (s *SPIFFEWorkloadAPIService) authenticateClient(
ctx context.Context,
) (*slog.Logger, workloadattest.Attestation, error) {
) (*slog.Logger, *workloadidentityv1pb.WorkloadAttrs, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, workloadattest.Attestation{}, trace.BadParameter("peer not found in context")
return nil, nil, trace.BadParameter("peer not found in context")
}
log := s.log

@@ -516,7 +531,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
// We expect Creds to be nil/unset if the client is connecting via TCP and
// therefore there is no workload attestation that can be completed.
if !ok || authInfo.Creds == nil {
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

// For a UDS, sometimes we are unable to determine the PID of the calling
@@ -528,7 +543,7 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
if authInfo.Creds.PID == 0 {
log.DebugContext(
ctx, "Failed to determine the PID of the calling workload. TBot may be running in a different process namespace to the workload. Workload attestation will not be completed.")
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}

att, err := s.attestor.Attest(ctx, authInfo.Creds.PID)
@@ -541,10 +556,10 @@ func (s *SPIFFEWorkloadAPIService) authenticateClient(
"error", err,
"pid", authInfo.Creds.PID,
)
return log, workloadattest.Attestation{}, nil
return log, nil, nil
}
log = log.With(
"workload", slog.LogValuer(att),
"workload", att,
)

return log, att, nil
27 changes: 7 additions & 20 deletions lib/tbot/service_spiffe_workload_api_sds.go
Original file line number Diff line number Diff line change
@@ -40,7 +40,6 @@ import (

"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
)

@@ -63,23 +62,18 @@ type bundleSetGetter interface {
GetBundleSet(ctx context.Context) (*workloadidentity.BundleSet, error)
}

type svidFetcher func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error)

// spiffeSDSHandler implements an Envoy SDS API.
//
// This effectively replaces the Workload API for Envoy, but functions in a
// very similar way.
type spiffeSDSHandler struct {
log *slog.Logger
cfg *config.SPIFFEWorkloadAPIService
botCfg *config.BotConfig
trustBundleCache bundleSetGetter

clientAuthenticator func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error)
svidFetcher func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest,
) ([]*workloadpb.X509SVID, error)
clientAuthenticator func(ctx context.Context) (*slog.Logger, svidFetcher, error)
}

// FetchSecrets implements
@@ -97,7 +91,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err)
}

log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return nil, trace.Wrap(err, "authenticating client")
}
@@ -114,11 +108,7 @@ func (s *spiffeSDSHandler) FetchSecrets(
return nil, trace.Wrap(err, "getting trust bundle set")
}

// Filter SVIDs down to those accessible to this workload
svids, err := s.svidFetcher(
ctx,
log,
bundleSet.Local, filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds))
svids, err := fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return nil, trace.Wrap(err, "fetching X509 SVIDs")
}
@@ -174,7 +164,7 @@ func (s *spiffeSDSHandler) StreamSecrets(
srv secretv3pb.SecretDiscoveryService_StreamSecretsServer,
) error {
ctx := srv.Context()
log, creds, err := s.clientAuthenticator(ctx)
log, fetchSVIDs, err := s.clientAuthenticator(ctx)
if err != nil {
return trace.Wrap(err, "authenticating client")
}
@@ -216,9 +206,6 @@ func (s *spiffeSDSHandler) StreamSecrets(
renewalTimer.Stop()
defer renewalTimer.Stop()

// Filter SVIDs down to those accessible to this workload
availableSVIDs := filterSVIDRequests(ctx, log, s.cfg.SVIDs, creds)

// Track the last response and last request to allow us to handle ACK/NACK
// and versioning.
var (
@@ -311,7 +298,7 @@ func (s *spiffeSDSHandler) StreamSecrets(

// Fetch the SVIDs if necessary
if svids == nil {
svids, err = s.svidFetcher(ctx, log, bundleSet.Local, availableSVIDs)
svids, err = fetchSVIDs(ctx, bundleSet.Local)
if err != nil {
return trace.Wrap(err, "fetching X509 SVIDs")
}
91 changes: 16 additions & 75 deletions lib/tbot/service_spiffe_workload_api_sds_test.go
Original file line number Diff line number Diff line change
@@ -35,7 +35,6 @@ import (
discoveryv3pb "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
secretv3pb "github.com/envoyproxy/go-control-plane/envoy/service/secret/v3"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/spiffe/go-spiffe/v2/bundle/spiffebundle"
workloadpb "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
"github.com/spiffe/go-spiffe/v2/spiffeid"
@@ -51,7 +50,6 @@ import (
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/testutils/golden"
"github.com/gravitational/teleport/tool/teleport/testenv"
@@ -80,14 +78,22 @@ func TestSDS_FetchSecrets(t *testing.T) {
ca, err := x509.ParseCertificate(b.Bytes)
require.NoError(t, err)

uid := 100
notUID := 200
clientAuthenticator := func(ctx context.Context) (*slog.Logger, workloadattest.Attestation, error) {
return log, workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
Attested: true,
UID: uid,
},
clientAuthenticator := func(ctx context.Context) (*slog.Logger, svidFetcher, error) {
return log, func(ctx context.Context, localBundle *spiffebundle.Bundle) ([]*workloadpb.X509SVID, error) {
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}, nil
}

@@ -105,72 +111,9 @@ func TestSDS_FetchSecrets(t *testing.T) {
},
},
}
svidFetcher := func(
ctx context.Context,
log *slog.Logger,
localBundle *spiffebundle.Bundle,
svidRequests []config.SVIDRequest) ([]*workloadpb.X509SVID, error) {
if len(svidRequests) != 2 {
return nil, trace.BadParameter("expected 2 svids requested")
}
return []*workloadpb.X509SVID{
{
SpiffeId: "spiffe://example.com/default",
X509Svid: []byte("CERT-spiffe://example.com/default"),
X509SvidKey: []byte("KEY-spiffe://example.com/default"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
{
SpiffeId: "spiffe://example.com/second",
X509Svid: []byte("CERT-spiffe://example.com/second"),
X509SvidKey: []byte("KEY-spiffe://example.com/second"),
Bundle: workloadidentity.MarshalX509Bundle(localBundle.X509Bundle()),
},
}, nil
}
botConfig := &config.BotConfig{
RenewalInterval: time.Minute,
}
cfg := &config.SPIFFEWorkloadAPIService{
SVIDs: []config.SVIDRequestWithRules{
{
SVIDRequest: config.SVIDRequest{
Path: "/default",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/second",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &uid,
},
},
},
},
{
SVIDRequest: config.SVIDRequest{
Path: "/not-matching",
},
Rules: []config.SVIDRequestRule{
{
Unix: config.SVIDRequestRuleUnix{
UID: &notUID,
},
},
},
},
},
}

tests := []struct {
name string
@@ -231,12 +174,10 @@ func TestSDS_FetchSecrets(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
sds := &spiffeSDSHandler{
log: log,
cfg: cfg,
botCfg: botConfig,

trustBundleCache: mockBundleCache,
clientAuthenticator: clientAuthenticator,
svidFetcher: svidFetcher,
}

req := &discoveryv3pb.DiscoveryRequest{
108 changes: 54 additions & 54 deletions lib/tbot/service_spiffe_workload_api_test.go
Original file line number Diff line number Diff line change
@@ -34,9 +34,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/tool/teleport/testenv"
)
@@ -52,7 +52,7 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) {
log := utils.NewSlogLoggerForTests()
tests := []struct {
name string
att workloadattest.Attestation
att *workloadidentityv1pb.WorkloadAttrs
in []config.SVIDRequestWithRules
want []config.SVIDRequest
}{
@@ -81,12 +81,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) {
},
{
name: "no rules with attestation",
att: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
att: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
UID: 1000,
GID: 1001,
PID: 1002,
Uid: 1000,
Gid: 1001,
Pid: 1002,
},
},
in: []config.SVIDRequestWithRules{
@@ -112,15 +112,15 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) {
},
{
name: "no rules with attestation",
att: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
att: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
// We don't expect that workloadattest will ever return
// Attested: false and include UID/PID/GID but we want to
// ensure we handle this by failing regardless.
Attested: false,
UID: 1000,
GID: 1001,
PID: 1002,
Uid: 1000,
Gid: 1001,
Pid: 1002,
},
},
in: []config.SVIDRequestWithRules{
@@ -141,12 +141,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) {
},
{
name: "no matching rules with attestation",
att: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
att: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
UID: 1000,
GID: 1001,
PID: 1002,
Uid: 1000,
Gid: 1001,
Pid: 1002,
},
},
in: []config.SVIDRequestWithRules{
@@ -220,12 +220,12 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests(t *testing.T) {
},
{
name: "some matching rules with uds",
att: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
att: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
UID: 1000,
GID: 1001,
PID: 1002,
Uid: 1000,
Gid: 1001,
Pid: 1002,
},
},
in: []config.SVIDRequestWithRules{
@@ -290,8 +290,8 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) {
log := utils.NewSlogLoggerForTests()
tests := []struct {
field string
matching workloadattest.Attestation
nonMatching workloadattest.Attestation
matching *workloadidentityv1pb.WorkloadAttrs
nonMatching *workloadidentityv1pb.WorkloadAttrs
rule config.SVIDRequestRule
}{
{
@@ -301,16 +301,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) {
PID: ptr(1000),
},
},
matching: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
matching: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
PID: 1000,
Pid: 1000,
},
},
nonMatching: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
nonMatching: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
PID: 200,
Pid: 200,
},
},
},
@@ -321,16 +321,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) {
UID: ptr(1000),
},
},
matching: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
matching: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
UID: 1000,
Uid: 1000,
},
},
nonMatching: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
nonMatching: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
UID: 200,
Uid: 200,
},
},
},
@@ -341,16 +341,16 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) {
GID: ptr(1000),
},
},
matching: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
matching: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
GID: 1000,
Gid: 1000,
},
},
nonMatching: workloadattest.Attestation{
Unix: workloadattest.UnixAttestation{
nonMatching: &workloadidentityv1pb.WorkloadAttrs{
Unix: &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
GID: 200,
Gid: 200,
},
},
},
@@ -361,14 +361,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) {
Namespace: "foo",
},
},
matching: workloadattest.Attestation{
Kubernetes: workloadattest.KubernetesAttestation{
matching: &workloadidentityv1pb.WorkloadAttrs{
Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
Namespace: "foo",
},
},
nonMatching: workloadattest.Attestation{
Kubernetes: workloadattest.KubernetesAttestation{
nonMatching: &workloadidentityv1pb.WorkloadAttrs{
Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
Namespace: "bar",
},
@@ -381,14 +381,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) {
ServiceAccount: "foo",
},
},
matching: workloadattest.Attestation{
Kubernetes: workloadattest.KubernetesAttestation{
matching: &workloadidentityv1pb.WorkloadAttrs{
Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
ServiceAccount: "foo",
},
},
nonMatching: workloadattest.Attestation{
Kubernetes: workloadattest.KubernetesAttestation{
nonMatching: &workloadidentityv1pb.WorkloadAttrs{
Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
ServiceAccount: "bar",
},
@@ -401,14 +401,14 @@ func TestSPIFFEWorkloadAPIService_filterSVIDRequests_field(t *testing.T) {
PodName: "foo",
},
},
matching: workloadattest.Attestation{
Kubernetes: workloadattest.KubernetesAttestation{
matching: &workloadidentityv1pb.WorkloadAttrs{
Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
PodName: "foo",
},
},
nonMatching: workloadattest.Attestation{
Kubernetes: workloadattest.KubernetesAttestation{
nonMatching: &workloadidentityv1pb.WorkloadAttrs{
Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
PodName: "bar",
},
48 changes: 11 additions & 37 deletions lib/tbot/workloadidentity/workloadattest/attest.go
Original file line number Diff line number Diff line change
@@ -23,32 +23,9 @@ import (
"log/slog"

"github.com/gravitational/trace"
)

// Attestation holds the results of the attestation process carried out on a
// PID by the attestor.
//
// The zero value of this type indicates that no attestation was performed or
// was successful.
type Attestation struct {
Unix UnixAttestation
Kubernetes KubernetesAttestation
}

// LogValue implements slog.LogValue to provide a nicely formatted set of
// log keys for a given attestation.
func (a Attestation) LogValue() slog.Value {
return slog.GroupValue(
slog.Attr{
Key: "unix",
Value: a.Unix.LogValue(),
},
slog.Attr{
Key: "kubernetes",
Value: a.Kubernetes.LogValue(),
},
)
}
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
)

type attestor[T any] interface {
Attest(ctx context.Context, pid int) (T, error)
@@ -58,8 +35,8 @@ type attestor[T any] interface {
// key information about the process.
type Attestor struct {
log *slog.Logger
kubernetes attestor[KubernetesAttestation]
unix attestor[UnixAttestation]
kubernetes attestor[*workloadidentityv1pb.WorkloadAttrsKubernetes]
unix attestor[*workloadidentityv1pb.WorkloadAttrsUnix]
}

// Config is the configuration for Attestor
@@ -83,30 +60,27 @@ func NewAttestor(log *slog.Logger, cfg Config) (*Attestor, error) {
return att, nil
}

func (a *Attestor) Attest(ctx context.Context, pid int) (Attestation, error) {
func (a *Attestor) Attest(ctx context.Context, pid int) (*workloadidentityv1pb.WorkloadAttrs, error) {
a.log.DebugContext(ctx, "Starting workload attestation", "pid", pid)
defer a.log.DebugContext(ctx, "Finished workload attestation", "pid", pid)

var (
att Attestation
err error
)

var err error
attrs := &workloadidentityv1pb.WorkloadAttrs{}
// We always perform the unix attestation first
att.Unix, err = a.unix.Attest(ctx, pid)
attrs.Unix, err = a.unix.Attest(ctx, pid)
if err != nil {
return att, err
return attrs, err
}

// Then we can perform the optionally configured attestations
// For these, failure is soft. If it fails, we log, but still return the
// successfully attested data.
if a.kubernetes != nil {
att.Kubernetes, err = a.kubernetes.Attest(ctx, pid)
attrs.Kubernetes, err = a.kubernetes.Attest(ctx, pid)
if err != nil {
a.log.WarnContext(ctx, "Failed to perform Kubernetes workload attestation", "error", err)
}
}

return att, nil
return attrs, nil
}
45 changes: 0 additions & 45 deletions lib/tbot/workloadidentity/workloadattest/kubernetes.go
Original file line number Diff line number Diff line change
@@ -19,54 +19,9 @@
package workloadattest

import (
"log/slog"

"github.com/gravitational/trace"
)

// KubernetesAttestation holds the Kubernetes pod information retrieved from
// the workload attestation process.
type KubernetesAttestation struct {
// Attested is true if the PID was successfully attested to a Kubernetes
// pod. This indicates the validity of the rest of the fields.
Attested bool
// Namespace is the namespace of the pod.
Namespace string
// ServiceAccount is the service account of the pod.
ServiceAccount string
// PodName is the name of the pod.
PodName string
// PodUID is the UID of the pod.
PodUID string
// Labels is a map of labels on the pod.
Labels map[string]string
}

// LogValue implements slog.LogValue to provide a nicely formatted set of
// log keys for a given attestation.
func (a KubernetesAttestation) LogValue() slog.Value {
values := []slog.Attr{
slog.Bool("attested", a.Attested),
}
if a.Attested {
labels := []slog.Attr{}
for k, v := range a.Labels {
labels = append(labels, slog.String(k, v))
}
values = append(values,
slog.String("namespace", a.Namespace),
slog.String("service_account", a.ServiceAccount),
slog.String("pod_name", a.PodName),
slog.String("pod_uid", a.PodUID),
slog.Attr{
Key: "labels",
Value: slog.GroupValue(labels...),
},
)
}
return slog.GroupValue(values...)
}

// KubernetesAttestorConfig holds the configuration for the KubernetesAttestor.
type KubernetesAttestorConfig struct {
// Enabled is true if the KubernetesAttestor is enabled. If false,
12 changes: 7 additions & 5 deletions lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go
Original file line number Diff line number Diff line change
@@ -41,6 +41,8 @@ import (
"github.com/gravitational/trace"
v1 "k8s.io/api/core/v1"
"k8s.io/utils/mount"

workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
)

// KubernetesAttestor attests a workload to a Kubernetes pod.
@@ -75,27 +77,27 @@ func NewKubernetesAttestor(cfg KubernetesAttestorConfig, log *slog.Logger) *Kube

// Attest resolves the Kubernetes pod information from the
// PID of the workload.
func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (KubernetesAttestation, error) {
func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (*workloadidentityv1pb.WorkloadAttrsKubernetes, error) {
a.log.DebugContext(ctx, "Starting Kubernetes workload attestation", "pid", pid)

podID, containerID, err := a.getContainerAndPodID(pid)
if err != nil {
return KubernetesAttestation{}, trace.Wrap(err, "determining pod and container ID")
return nil, trace.Wrap(err, "determining pod and container ID")
}
a.log.DebugContext(ctx, "Found pod and container ID", "pod_id", podID, "container_id", containerID)

pod, err := a.getPodForID(ctx, podID)
if err != nil {
return KubernetesAttestation{}, trace.Wrap(err, "finding pod by ID")
return nil, trace.Wrap(err, "finding pod by ID")
}
a.log.DebugContext(ctx, "Found pod", "pod_name", pod.Name)

att := KubernetesAttestation{
att := &workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
Namespace: pod.Namespace,
ServiceAccount: pod.Spec.ServiceAccountName,
PodName: pod.Name,
PodUID: string(pod.UID),
PodUid: string(pod.UID),
Labels: pod.Labels,
}
a.log.DebugContext(ctx, "Finished Kubernetes workload attestation", "attestation", att)
Original file line number Diff line number Diff line change
@@ -31,12 +31,15 @@ import (
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"

workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/lib/utils"
)

@@ -165,14 +168,14 @@ func TestKubernetesAttestor_Attest(t *testing.T) {

att, err := attestor.Attest(ctx, mockPID)
assert.NoError(t, err)
assert.Equal(t, KubernetesAttestation{
assert.Empty(t, cmp.Diff(&workloadidentityv1pb.WorkloadAttrsKubernetes{
Attested: true,
ServiceAccount: "my-service-account",
Namespace: "default",
PodName: "my-pod",
PodUID: mockPodID,
PodUid: mockPodID,
Labels: map[string]string{
"my-label": "my-label-value",
},
}, att)
}, att, protocmp.Transform()))
}
Original file line number Diff line number Diff line change
@@ -25,14 +25,16 @@ import (
"log/slog"

"github.com/gravitational/trace"

workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
)

// WindowsKubernetesAttestor is the windows stub for KubernetesAttestor.
type WindowsKubernetesAttestor struct {
}

func (a WindowsKubernetesAttestor) Attest(_ context.Context, _ int) (KubernetesAttestation, error) {
return KubernetesAttestation{}, trace.NotImplemented("kubernetes attestation is not supported on windows")
func (a WindowsKubernetesAttestor) Attest(_ context.Context, _ int) (*workloadidentityv1pb.WorkloadAttrsKubernetes, error) {
return nil, trace.NotImplemented("kubernetes attestation is not supported on windows")
}

// NewKubernetesAttestor creates a new KubernetesAttestor.
57 changes: 14 additions & 43 deletions lib/tbot/workloadidentity/workloadattest/unix.go
Original file line number Diff line number Diff line change
@@ -20,41 +20,12 @@ package workloadattest

import (
"context"
"log/slog"

"github.com/gravitational/trace"
"github.com/shirou/gopsutil/v4/process"
)

// UnixAttestation holds the Unix process information retrieved from the
// workload attestation process.
type UnixAttestation struct {
// Attested is true if the PID was successfully attested to a Unix
// process. This indicates the validity of the rest of the fields.
Attested bool
// PID is the process ID of the attested process.
PID int
// UID is the primary user ID of the attested process.
UID int
// GID is the primary group ID of the attested process.
GID int
}

// LogValue implements slog.LogValue to provide a nicely formatted set of
// log keys for a given attestation.
func (a UnixAttestation) LogValue() slog.Value {
values := []slog.Attr{
slog.Bool("attested", a.Attested),
}
if a.Attested {
values = append(values,
slog.Int("uid", a.UID),
slog.Int("pid", a.PID),
slog.Int("gid", a.GID),
)
}
return slog.GroupValue(values...)
}
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
)

// UnixAttestor attests a process id to a Unix process.
type UnixAttestor struct {
@@ -66,35 +37,35 @@ func NewUnixAttestor() *UnixAttestor {
}

// Attest attests a process id to a Unix process.
func (a *UnixAttestor) Attest(ctx context.Context, pid int) (UnixAttestation, error) {
func (a *UnixAttestor) Attest(ctx context.Context, pid int) (*workloadidentityv1pb.WorkloadAttrsUnix, error) {
p, err := process.NewProcessWithContext(ctx, int32(pid))
if err != nil {
return UnixAttestation{}, trace.Wrap(err, "getting process")
return nil, trace.Wrap(err, "getting process")
}

att := UnixAttestation{
att := &workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
PID: pid,
Pid: int32(pid),
}
// On Linux:
// Real, effective, saved, and file system GIDs
// On Darwin:
// Effective, effective, saved GIDs
gids, err := p.Gids()
if err != nil {
return UnixAttestation{}, trace.Wrap(err, "getting gids")
return nil, trace.Wrap(err, "getting gids")
}
// We generally want to select the effective GID.
switch len(gids) {
case 0:
// error as none returned
return UnixAttestation{}, trace.BadParameter("no gids returned")
return nil, trace.BadParameter("no gids returned")
case 1:
// Only one GID - this is unusual but let's take it.
att.GID = int(gids[0])
att.Gid = gids[0]
default:
// Take the index 1 entry as this is effective
att.GID = int(gids[1])
att.Gid = gids[1]
}

// On Linux:
@@ -103,19 +74,19 @@ func (a *UnixAttestor) Attest(ctx context.Context, pid int) (UnixAttestation, er
// Effective
uids, err := p.Uids()
if err != nil {
return UnixAttestation{}, trace.Wrap(err, "getting uids")
return nil, trace.Wrap(err, "getting uids")
}
// We generally want to select the effective GID.
switch len(uids) {
case 0:
// error as none returned
return UnixAttestation{}, trace.BadParameter("no uids returned")
return nil, trace.BadParameter("no uids returned")
case 1:
// Only one UID, we expect this on Darwin to be the Effective UID
att.UID = int(uids[0])
att.Uid = uids[0]
default:
// Take the index 1 entry as this is Effective UID on Linux
att.UID = int(uids[1])
att.Uid = uids[1]
}

return att, nil
14 changes: 9 additions & 5 deletions lib/tbot/workloadidentity/workloadattest/unix_test.go
Original file line number Diff line number Diff line change
@@ -23,7 +23,11 @@ import (
"os"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"

workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
)

func TestUnixAttestor_Attest(t *testing.T) {
@@ -37,10 +41,10 @@ func TestUnixAttestor_Attest(t *testing.T) {
attestor := NewUnixAttestor()
att, err := attestor.Attest(ctx, pid)
require.NoError(t, err)
require.Equal(t, UnixAttestation{
require.Empty(t, cmp.Diff(&workloadidentityv1pb.WorkloadAttrsUnix{
Attested: true,
PID: pid,
UID: uid,
GID: gid,
}, att)
Pid: int32(pid),
Uid: uint32(uid),
Gid: uint32(gid),
}, att, protocmp.Transform()))
}