Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
mcbattirola committed Aug 28, 2024
1 parent aa70c2b commit b610f4c
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 21 deletions.
5 changes: 3 additions & 2 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ const (
IncludedResourceModeAll = "all"
// DefaultLicenseWatchInterval is the default time in which the license watcher
// should ping the auth server for new cluster features
DefaultLicenseWatchInterval = time.Second * 1 // time.Minute * 2
DefaultLicenseWatchInterval = time.Minute * 5
)

// healthCheckAppServerFunc defines a function used to perform a health check
Expand Down Expand Up @@ -179,6 +179,7 @@ type Handler struct {
// featureWatcherStop is a channel used to emit a stop signal to the
// license watcher goroutine
featureWatcherStop chan struct{}
featureWatcherOnce sync.Once
}

// HandlerOption is a functional argument - an option that can be passed
Expand Down Expand Up @@ -1642,7 +1643,7 @@ func (h *Handler) getWebConfig(w http.ResponseWriter, r *http.Request, p httprou
}
}

clusterFeatures := h.clusterFeatures
clusterFeatures := h.GetClusterFeatures()

// get tunnel address to display on cloud instances
tunnelPublicAddr := ""
Expand Down
40 changes: 21 additions & 19 deletions lib/web/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,30 @@ func (h *Handler) GetClusterFeatures() proto.Features {
}

func (h *Handler) startFeaturesWatcher() {
ticker := h.clock.NewTicker(h.cfg.LicenseWatchInterval)
h.log.WithField("interval", h.cfg.LicenseWatchInterval).Info("Proxy handler features watcher has started")
ctx := context.Background()
h.featureWatcherOnce.Do(func() {
ticker := h.clock.NewTicker(h.cfg.LicenseWatchInterval)
h.log.WithField("interval", h.cfg.LicenseWatchInterval).Info("Proxy handler features watcher has started")
ctx := context.Background()

defer ticker.Stop()
for {
select {
case <-ticker.Chan():
h.log.Info("Pinging auth server for features")
f, err := h.GetProxyClient().Ping(ctx)
if err != nil {
h.log.WithError(err).Error("Failed fetching features")
continue
}
defer ticker.Stop()
for {
select {
case <-ticker.Chan():
h.log.Info("Pinging auth server for features")
f, err := h.cfg.ProxyClient.Ping(ctx)
if err != nil {
h.log.WithError(err).Error("Auth server ping failed")
continue
}

h.SetClusterFeatures(*f.ServerFeatures)
h.log.WithField("features", f.ServerFeatures).Infof("Done updating proxy features: %+v", f)
case <-h.featureWatcherStop:
h.log.Info("Feature service has stopped")
return
h.SetClusterFeatures(*f.ServerFeatures)
h.log.Debug("Done updating proxy features")
case <-h.featureWatcherStop:
h.log.Info("Feature service has stopped")
return
}
}
}
})
}

func (h *Handler) stopFeaturesWatcher() {
Expand Down
145 changes: 145 additions & 0 deletions lib/web/features_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package web

import (
"context"
"log/slog"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/entitlements"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestFeaturesWatcher(t *testing.T) {
clock := clockwork.NewFakeClock()
mockedFeatures := proto.Features{
Kubernetes: true,
Entitlements: map[string]*proto.EntitlementInfo{},
AccessRequests: &proto.AccessRequestsFeature{},
}

handler := &Handler{
cfg: Config{
LicenseWatchInterval: 100 * time.Millisecond,
ProxyClient: &mockedPingTestProxy{
mockedPing: func(ctx context.Context) (proto.PingResponse, error) {
return proto.PingResponse{
ServerFeatures: &mockedFeatures,
}, nil
},
},
},
clock: clock,
clusterFeatures: proto.Features{},
featureWatcherStop: make(chan struct{}),
log: newPackageLogger(),
logger: slog.Default().With(teleport.ComponentKey, teleport.ComponentWeb),
}

// before running the watcher, features should match the value passed to the handler
requireFeatures(t, clock, proto.Features{}, handler.GetClusterFeatures)

go handler.startFeaturesWatcher()
clock.BlockUntil(1)

// after starting the watcher, handler.GetClusterFeatures should return
// values matching the client's response
features := proto.Features{
Kubernetes: true,
Entitlements: map[string]*proto.EntitlementInfo{},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(&features)
expected := utils.CloneProtoMsg(&features)
requireFeatures(t, clock, *expected, handler.GetClusterFeatures)

// update values once again and check if the features are properly updated
features = proto.Features{
Kubernetes: false,
Entitlements: map[string]*proto.EntitlementInfo{},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(&features)
mockedFeatures = features
expected = utils.CloneProtoMsg(&features)
requireFeatures(t, clock, *expected, handler.GetClusterFeatures)

// test updating entitlements
features = proto.Features{
Kubernetes: true,
Entitlements: map[string]*proto.EntitlementInfo{
string(entitlements.ExternalAuditStorage): {Enabled: true},
string(entitlements.AccessLists): {Enabled: true},
string(entitlements.AccessMonitoring): {Enabled: true},
string(entitlements.App): {Enabled: true},
string(entitlements.CloudAuditLogRetention): {Enabled: true},
},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(&features)
mockedFeatures = features

expected = &proto.Features{
Kubernetes: true,
Entitlements: map[string]*proto.EntitlementInfo{
string(entitlements.ExternalAuditStorage): {Enabled: true},
string(entitlements.AccessLists): {Enabled: true},
string(entitlements.AccessMonitoring): {Enabled: true},
string(entitlements.App): {Enabled: true},
string(entitlements.CloudAuditLogRetention): {Enabled: true},
},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(expected)
requireFeatures(t, clock, *expected, handler.GetClusterFeatures)

// stop watcher and ensure it stops updating features
handler.stopFeaturesWatcher()
features = proto.Features{
Kubernetes: !features.Kubernetes,
App: !features.App,
DB: true,
Entitlements: map[string]*proto.EntitlementInfo{},
AccessRequests: &proto.AccessRequestsFeature{},
}
entitlements.SupportEntitlementsCompatibility(&features)
mockedFeatures = features
expected = utils.CloneProtoMsg(&features)
// assert the handler never get these last features as the watcher is stopped
neverFeatures(t, clock, *expected, handler.GetClusterFeatures)
}

// requireFeatures is a helper function that advances the clock, then
// calls `getFeatures` every 100ms for up to 1 second, until it
// returns the expected result (`want`).
func requireFeatures(t *testing.T, fakeClock clockwork.FakeClock, want proto.Features, getFeatures func() proto.Features) {
t.Helper()

// Advance the clock so the service fetch and stores features
fakeClock.Advance(1 * time.Second)

require.EventuallyWithT(t, func(c *assert.CollectT) {
diff := cmp.Diff(want, getFeatures())
if !assert.Empty(c, diff) {
t.Logf("Feature diff (-want +got):\n%s", diff)
}
}, 1*time.Second, time.Millisecond*100)
}

// neverFeatures is a helper function that advances the clock, then
// calls `getFeatures` every 100ms for up to 1 second. If at some point `getFeatures`
// returns `doNotWant`, the test fails.
func neverFeatures(t *testing.T, fakeClock clockwork.FakeClock, doNotWant proto.Features, getFeatures func() proto.Features) {
t.Helper()

fakeClock.Advance(1 * time.Second)
require.Never(t, func() bool {
return cmp.Diff(doNotWant, getFeatures()) == ""
}, 1*time.Second, time.Millisecond*100)
}

0 comments on commit b610f4c

Please sign in to comment.