From 5c2a107f681049604e9489e77a197c39fdc92702 Mon Sep 17 00:00:00 2001 From: James Pickett Date: Mon, 9 Dec 2024 12:58:46 -0800 Subject: [PATCH] add multiplicative ticker, use in runners --- ee/secureenclaverunner/secureenclaverunner.go | 9 +- ee/tpmrunner/tpmrunner.go | 9 +- pkg/backoff/ticker.go | 77 +++++++++++++ pkg/backoff/ticker_test.go | 105 ++++++++++++++++++ 4 files changed, 186 insertions(+), 14 deletions(-) create mode 100644 pkg/backoff/ticker.go create mode 100644 pkg/backoff/ticker_test.go diff --git a/ee/secureenclaverunner/secureenclaverunner.go b/ee/secureenclaverunner/secureenclaverunner.go index c8a39af9c..fec41a49d 100644 --- a/ee/secureenclaverunner/secureenclaverunner.go +++ b/ee/secureenclaverunner/secureenclaverunner.go @@ -20,6 +20,7 @@ import ( "github.com/kolide/launcher/ee/agent/types" "github.com/kolide/launcher/ee/consoleuser" + "github.com/kolide/launcher/pkg/backoff" "github.com/kolide/launcher/pkg/traces" ) @@ -76,8 +77,7 @@ func (ser *secureEnclaveRunner) Execute() error { } } - currentRetryInterval, maxRetryInterval := 1*time.Second, 1*time.Minute - retryTicker := time.NewTicker(currentRetryInterval) + retryTicker := backoff.NewMultiplicativeTicker(time.Second, time.Minute) defer retryTicker.Stop() for { @@ -87,11 +87,6 @@ func (ser *secureEnclaveRunner) Execute() error { "getting current console user key, will retry", "err", err, ) - - if currentRetryInterval < maxRetryInterval { - currentRetryInterval += time.Second - retryTicker.Reset(currentRetryInterval) - } } else { retryTicker.Stop() } diff --git a/ee/tpmrunner/tpmrunner.go b/ee/tpmrunner/tpmrunner.go index 5588ac81b..d58c44e3f 100644 --- a/ee/tpmrunner/tpmrunner.go +++ b/ee/tpmrunner/tpmrunner.go @@ -11,6 +11,7 @@ import ( "github.com/kolide/krypto/pkg/tpm" "github.com/kolide/launcher/ee/agent/types" + "github.com/kolide/launcher/pkg/backoff" "github.com/kolide/launcher/pkg/traces" ) @@ -68,8 +69,7 @@ func New(ctx context.Context, slogger *slog.Logger, store types.GetterSetterDele // Public returns the public key of the current console user // creating and peristing a new one if needed func (tr *tpmRunner) Execute() error { - currentRetryInterval, maxRetryInterval := 1*time.Second, 1*time.Minute - retryTicker := time.NewTicker(currentRetryInterval) + retryTicker := backoff.NewMultiplicativeTicker(time.Second, time.Minute) defer retryTicker.Stop() for { @@ -80,11 +80,6 @@ func (tr *tpmRunner) Execute() error { "creating tpm signer, will retry", "err", err, ) - - if currentRetryInterval < maxRetryInterval { - currentRetryInterval += time.Second - retryTicker.Reset(currentRetryInterval) - } } else { tr.signer = signer retryTicker.Stop() diff --git a/pkg/backoff/ticker.go b/pkg/backoff/ticker.go new file mode 100644 index 000000000..d0333e00e --- /dev/null +++ b/pkg/backoff/ticker.go @@ -0,0 +1,77 @@ +package backoff + +import ( + "time" +) + +// NewMultiplicativeTicker returns a ticker where each interval = baseDuration * ticks until maxDuration is reached. +func NewMultiplicativeTicker(baseDuration, maxDuration time.Duration) *ticker { + return newTicker(newMultiplicativeCounter(baseDuration, maxDuration)) +} + +type ticker struct { + C chan time.Time + baseTicker *time.Ticker + stoppedChan chan struct{} + stopped bool + durationCounter durationCounter +} + +func newTicker(durationCounter *durationCounter) *ticker { + thisTicker := &ticker{ + C: make(chan time.Time), + stoppedChan: make(chan struct{}), + durationCounter: *durationCounter, + } + + thisTicker.baseTicker = time.NewTicker(thisTicker.durationCounter.next()) + + go func() { + for { + select { + case t := <-thisTicker.baseTicker.C: + thisTicker.baseTicker.Reset(thisTicker.durationCounter.next()) + thisTicker.C <- t + case <-thisTicker.stoppedChan: + thisTicker.baseTicker.Stop() + return + } + } + }() + + return thisTicker +} + +func (t *ticker) Stop() { + if t.stopped { + return + } + + t.stopped = true + close(t.stoppedChan) +} + +type durationCounter struct { + count int + baseInterval, maxInterval time.Duration + calcNext func(count int, baseDuration time.Duration) time.Duration +} + +func (dc *durationCounter) next() time.Duration { + dc.count++ + interval := dc.calcNext(dc.count, dc.baseInterval) + if interval > dc.maxInterval { + return dc.maxInterval + } + return interval +} + +func newMultiplicativeCounter(baseDuration, maxDuration time.Duration) *durationCounter { + return &durationCounter{ + baseInterval: baseDuration, + maxInterval: maxDuration, + calcNext: func(count int, baseInterval time.Duration) time.Duration { + return baseInterval * time.Duration(count) + }, + } +} diff --git a/pkg/backoff/ticker_test.go b/pkg/backoff/ticker_test.go new file mode 100644 index 000000000..73eda66d9 --- /dev/null +++ b/pkg/backoff/ticker_test.go @@ -0,0 +1,105 @@ +package backoff + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestMultiplicativeCounter(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseInterval time.Duration + maxInterval time.Duration + expected []time.Duration + }{ + { + name: "seconds", + baseInterval: time.Second, + maxInterval: 5 * time.Second, + expected: []time.Duration{ + time.Second, // 1s + 2 * time.Second, // 2s + 3 * time.Second, // 3s + 4 * time.Second, // 4s + 5 * time.Second, // 5s (max interval) + 5 * time.Second, // capped at max interval + }, + }, + { + name: "minutes", + baseInterval: time.Minute, + maxInterval: 3 * time.Minute, + expected: []time.Duration{ + time.Minute, // 1m + 2 * time.Minute, // 2m + 3 * time.Minute, // 3m (max interval) + 3 * time.Minute, // capped at max interval + 3 * time.Minute, // capped at max interval + }, + }, + { + name: "combo", + baseInterval: (1 * time.Minute) + (30 * time.Second), + maxInterval: 5 * time.Minute, + expected: []time.Duration{ + (1 * time.Minute) + (30 * time.Second), // 1m30s + 2 * ((1 * time.Minute) + (30 * time.Second)), // 3m + 3 * ((1 * time.Minute) + (30 * time.Second)), // 4m30s + 5 * time.Minute, // 5m + 5 * time.Minute, // 5m + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ec := newMultiplicativeCounter(tt.baseInterval, tt.maxInterval) + for _, expected := range tt.expected { + require.Equal(t, expected, ec.next()) + } + }) + } +} + +// TestMultiplicativeTicker tests the NewMultiplicativeTicker and its behavior. +func TestMultiplicativeTicker(t *testing.T) { + baseTime := 100 * time.Millisecond + maxTime := 500 * time.Millisecond + + tk := NewMultiplicativeTicker(baseTime, maxTime) + defer tk.Stop() + + expectedDurations := []time.Duration{ + 100 * time.Millisecond, + 200 * time.Millisecond, + 300 * time.Millisecond, + 400 * time.Millisecond, + 500 * time.Millisecond, // maxTime limit + 500 * time.Millisecond, // maxTime limit + } + + buffer := 25 * time.Millisecond + + for _, expected := range expectedDurations { + start := time.Now() + + select { + case <-tk.C: + require.WithinDuration(t, start, time.Now(), expected+buffer) + case <-time.After(maxTime + buffer): + t.Fatalf("ticker did not send event in expected time: %v", expected) + } + } + + // stop the ticker + tk.Stop() + + // call stop again to make sure no panic (same as stdlib ticker) + tk.Stop() +}