Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
mcbattirola committed Aug 28, 2024
1 parent aa70c2b commit 7ef795d
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 22 deletions.
3 changes: 2 additions & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ const (
IncludedResourceModeAll = "all"
// DefaultLicenseWatchInterval is the default time in which the license watcher
// should ping the auth server for new cluster features
DefaultLicenseWatchInterval = time.Second * 1 // time.Minute * 2
DefaultLicenseWatchInterval = time.Second * 5 // time.Minute * 2
)

// healthCheckAppServerFunc defines a function used to perform a health check
Expand Down Expand Up @@ -179,6 +179,7 @@ type Handler struct {
// featureWatcherStop is a channel used to emit a stop signal to the
// license watcher goroutine
featureWatcherStop chan struct{}
featureWatcherOnce sync.Once
}

// HandlerOption is a functional argument - an option that can be passed
Expand Down
51 changes: 31 additions & 20 deletions lib/web/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,40 @@ func (h *Handler) GetClusterFeatures() proto.Features {
return h.clusterFeatures
}

func (h *Handler) startFeaturesWatcher() {
ticker := h.clock.NewTicker(h.cfg.LicenseWatchInterval)
h.log.WithField("interval", h.cfg.LicenseWatchInterval).Info("Proxy handler features watcher has started")
ctx := context.Background()
// featuresClient is an interface with the Ping method.
// authclient.ClientI client implements this interface.
type featuresClient interface {
Ping(ctx context.Context) (proto.PingResponse, error)
}

defer ticker.Stop()
for {
select {
case <-ticker.Chan():
h.log.Info("Pinging auth server for features")
f, err := h.GetProxyClient().Ping(ctx)
if err != nil {
h.log.WithError(err).Error("Failed fetching features")
continue
}
func (h *Handler) startFeaturesWatcher(client featuresClient) {
h.log.Warn("starting startFeaturesWatcher")
h.featureWatcherOnce.Do(func() {
ticker := h.clock.NewTicker(h.cfg.LicenseWatchInterval)
h.log.WithField("interval", h.cfg.LicenseWatchInterval).Info("Proxy handler features watcher has started")
ctx := context.Background()

h.SetClusterFeatures(*f.ServerFeatures)
h.log.WithField("features", f.ServerFeatures).Infof("Done updating proxy features: %+v", f)
case <-h.featureWatcherStop:
h.log.Info("Feature service has stopped")
return
h.log.Warn("starting startFeaturesWatcher LOOP")
defer ticker.Stop()
for {
select {
case <-ticker.Chan():
h.log.Info("Pinging auth server for features")
f, err := client.Ping(ctx)
if err != nil {
h.log.WithError(err).Error("Failed fetching features")
continue
}

h.SetClusterFeatures(*f.ServerFeatures)
h.log.Debugf("identity: %t, policy: %t", h.GetClusterFeatures().IdentityGovernance, h.GetClusterFeatures().Policy)
// h.log.WithField("features", f.ServerFeatures).Debug("Done updating proxy features")
case <-h.featureWatcherStop:
h.log.Info("Feature service has stopped")
return
}
}
}
})
}

func (h *Handler) stopFeaturesWatcher() {
Expand Down
149 changes: 149 additions & 0 deletions lib/web/features_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package web

import (
"context"
"log/slog"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/entitlements"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// mockFeaturesClient is a mock implementation of the `featuresClient` interface
// that always returns its internal features with support for entitlement compatibility on Ping
type mockFeaturesClient struct {
features proto.Features
}

func (m *mockFeaturesClient) Ping(ctx context.Context) (proto.PingResponse, error) {
return proto.PingResponse{
ServerFeatures: &m.features,
}, nil
}

func TestFeaturesWatcher(t *testing.T) {
clock := clockwork.NewFakeClock()
handler := &Handler{
cfg: Config{
LicenseWatchInterval: 100 * time.Millisecond,
},
clock: clock,
clusterFeatures: proto.Features{},
featureWatcherStop: make(chan struct{}),
log: newPackageLogger(),
logger: slog.Default().With(teleport.ComponentKey, teleport.ComponentWeb),
}

features := proto.Features{
Kubernetes: true,
Entitlements: map[string]*proto.EntitlementInfo{},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(&features)

client := mockFeaturesClient{
features,
}
// before running the service, features should match the value passed in
// to the handler
requireFeatures(t, clock, proto.Features{}, handler.GetClusterFeatures)

go handler.startFeaturesWatcher(&client)
clock.BlockUntil(1)
// after starting the service, handler.GetClusterFeatures should return
// values matching the client's response
expected := utils.CloneProtoMsg(&features)
requireFeatures(t, clock, *expected, handler.GetClusterFeatures)

// update values once again and check if the features are properly updated
features = proto.Features{
Kubernetes: false,
Entitlements: map[string]*proto.EntitlementInfo{},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(&features)
client.features = features
expected = utils.CloneProtoMsg(&features)
requireFeatures(t, clock, *expected, handler.GetClusterFeatures)

// test entitlements explicitly
features = proto.Features{
Kubernetes: false,
Entitlements: map[string]*proto.EntitlementInfo{
string(entitlements.ExternalAuditStorage): {Enabled: true},
string(entitlements.AccessLists): {Enabled: true},
string(entitlements.AccessMonitoring): {Enabled: true},
string(entitlements.App): {Enabled: true},
string(entitlements.CloudAuditLogRetention): {Enabled: true},
},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(&features)
client.features = features

expected = &proto.Features{
Kubernetes: false,
Entitlements: map[string]*proto.EntitlementInfo{
string(entitlements.ExternalAuditStorage): {Enabled: true},
string(entitlements.AccessLists): {Enabled: true},
string(entitlements.AccessMonitoring): {Enabled: true},
string(entitlements.App): {Enabled: true},
string(entitlements.CloudAuditLogRetention): {Enabled: true},
},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(expected)
requireFeatures(t, clock, *expected, handler.GetClusterFeatures)

// call stop and ensure it stops updating features
handler.stopFeaturesWatcher()
previousFeatures := utils.CloneProtoMsg(&features)
// toggle some features and update the mock client
features = proto.Features{
Kubernetes: !previousFeatures.Kubernetes,
App: !previousFeatures.App,
Entitlements: map[string]*proto.EntitlementInfo{},
AccessRequests: &proto.AccessRequestsFeature{},
}
client.features = features
neverFeatures(t, clock, features, handler.GetClusterFeatures)
}

// requireFeatures is a helper function that advances the clock, then
// calls `getFeatures` every 100ms for up to 1 second, until it
// returns the expected result (`want`).
func requireFeatures(t *testing.T, fakeClock clockwork.FakeClock, want proto.Features, getFeatures func() proto.Features) {
t.Helper()

// Advance the clock so the service fetch and stores features
fakeClock.Advance(1 * time.Second)

require.EventuallyWithT(t, func(c *assert.CollectT) {
diff := cmp.Diff(want, getFeatures())
if !assert.Empty(c, diff) {
t.Logf("Feature diff (-want +got):\n%s", diff)
}
}, 1*time.Second, time.Millisecond*100)
}

func neverFeatures(t *testing.T, fakeClock clockwork.FakeClock, want proto.Features, getFeatures func() proto.Features) {
t.Helper()

// Advance the clock so the service fetch and stores features
fakeClock.Advance(1 * time.Second)

require.Never(t, func() bool {
diff := cmp.Diff(want, getFeatures())
if diff == "" {
return true
}
return false
}, 1*time.Second, time.Millisecond*100)
}
2 changes: 1 addition & 1 deletion lib/web/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (s *Server) Serve(l net.Listener) error {
if closed {
return trace.Errorf("serve called on previously closed server")
}
go s.cfg.Handler.handler.startFeaturesWatcher()
go s.cfg.Handler.handler.startFeaturesWatcher(s.cfg.Handler.handler.GetProxyClient())
return trace.Wrap(s.cfg.Server.Serve(l))
}

Expand Down

0 comments on commit 7ef795d

Please sign in to comment.