diff --git a/lib/srv/heartbeatv2.go b/lib/srv/heartbeatv2.go index 8ce07fd666a3b..bdbc1a4de8423 100644 --- a/lib/srv/heartbeatv2.go +++ b/lib/srv/heartbeatv2.go @@ -59,6 +59,9 @@ type HeartbeatV2Config[T any] struct { OnHeartbeat func(error) // AnnounceInterval is the interval at which heartbeats are attempted (optional). AnnounceInterval time.Duration + // DisruptionAnnounceInterval is the interval at which heartbeats are attempted + // if there was a disruption in the control stream since the last heartbeat (optional). + DisruptionAnnounceInterval time.Duration // PollInterval is the interval at which checks for change are performed (optional). PollInterval time.Duration } @@ -99,9 +102,10 @@ func NewSSHServerHeartbeat(cfg HeartbeatV2Config[*types.ServerV2]) (*HeartbeatV2 } return newHeartbeatV2(cfg.InventoryHandle, inner, heartbeatV2Config{ - onHeartbeatInner: cfg.OnHeartbeat, - announceInterval: cfg.AnnounceInterval, - pollInterval: cfg.PollInterval, + onHeartbeatInner: cfg.OnHeartbeat, + announceInterval: cfg.AnnounceInterval, + disruptionAnnounceInterval: cfg.DisruptionAnnounceInterval, + pollInterval: cfg.PollInterval, }), nil } @@ -118,9 +122,10 @@ func NewAppServerHeartbeat(cfg HeartbeatV2Config[*types.AppServerV3]) (*Heartbea } return newHeartbeatV2(cfg.InventoryHandle, inner, heartbeatV2Config{ - onHeartbeatInner: cfg.OnHeartbeat, - announceInterval: cfg.AnnounceInterval, - pollInterval: cfg.PollInterval, + onHeartbeatInner: cfg.OnHeartbeat, + announceInterval: cfg.AnnounceInterval, + disruptionAnnounceInterval: cfg.DisruptionAnnounceInterval, + pollInterval: cfg.PollInterval, }), nil } @@ -208,9 +213,10 @@ type HeartbeatV2 struct { } type heartbeatV2Config struct { - announceInterval time.Duration - pollInterval time.Duration - onHeartbeatInner func(error) + announceInterval time.Duration + disruptionAnnounceInterval time.Duration + pollInterval time.Duration + onHeartbeatInner func(error) // -- below values only used in tests @@ -226,6 +232,11 @@ 
func (c *heartbeatV2Config) SetDefaults() { // from the average of ~5m30s that was used for V1 ssh server heartbeats. c.announceInterval = 2 * (apidefaults.ServerAnnounceTTL / 3) } + if c.disruptionAnnounceInterval == 0 { + // if there was a disruption in the control stream, we want to heartbeat a bit + // sooner in case the disruption affected the most recent announce's success. + c.disruptionAnnounceInterval = 2 * (c.announceInterval / 3) + } if c.pollInterval == 0 { c.pollInterval = defaults.HeartbeatCheckPeriod } @@ -356,6 +367,21 @@ func (h *HeartbeatV2) runWithSender(sender inventory.DownstreamSender) { h.shouldAnnounce = true } + // in the event of disruption, we want to heartbeat a bit sooner than normal. + // this helps prevent node heartbeats from getting too stale when auth servers fail + // in a manner that isn't immediately detected by the agent (e.g. deadlock, + // i/o timeout, etc). Since we're heartbeating over a channel, such failure modes + // can sometimes mean that the last announce failed "silently" from our perspective. 
+ if t, ok := h.announce.LastTick(); ok { + elapsed := time.Since(t) + dai := utils.SeventhJitter(h.disruptionAnnounceInterval) + if elapsed >= dai { + h.shouldAnnounce = true + } else { + h.announce.ResetTo(dai - elapsed) + } + } + for { if h.shouldAnnounce { if ok := h.inner.Announce(h.closeContext, sender); ok { diff --git a/lib/utils/interval/interval.go b/lib/utils/interval/interval.go index 67003c98c4416..2bde9f81f66b5 100644 --- a/lib/utils/interval/interval.go +++ b/lib/utils/interval/interval.go @@ -21,6 +21,7 @@ package interval import ( "errors" "sync" + "sync/atomic" "time" "github.com/jonboulle/clockwork" @@ -38,8 +39,9 @@ import ( type Interval struct { cfg Config ch chan time.Time - reset chan struct{} + reset chan time.Duration fire chan struct{} + lastTick atomic.Pointer[time.Time] closeOnce sync.Once done chan struct{} } @@ -88,7 +90,7 @@ func New(cfg Config) *Interval { interval := &Interval{ ch: make(chan time.Time, 1), cfg: cfg, - reset: make(chan struct{}), + reset: make(chan time.Duration), fire: make(chan struct{}), done: make(chan struct{}), } @@ -121,7 +123,15 @@ func (i *Interval) Stop() { // jitter(duration) regardless of current timer progress). func (i *Interval) Reset() { select { - case i.reset <- struct{}{}: + case i.reset <- time.Duration(0): + case <-i.done: + } +} + +// ResetTo resets the interval to the target duration for the next tick only; subsequent ticks use the configured duration. A zero duration is treated the same as Reset (the configured duration, with jitter if one was supplied, is used). +func (i *Interval) ResetTo(d time.Duration) { + select { + case i.reset <- d: case <-i.done: } } @@ -140,6 +150,20 @@ func (i *Interval) Next() <-chan time.Time { return i.ch } +// LastTick gets the most recent tick if the interval has fired at least once. Note that the +// tick returned by this method is the last *generated* tick, not necessarily the last tick +// that was *observed* by the consumer of the interval. 
+func (i *Interval) LastTick() (tick time.Time, ok bool) { + if t := i.lastTick.Load(); t != nil { + return *t, true + } + return time.Time{}, false +} + +func (i *Interval) setLastTick(tick time.Time) { + i.lastTick.Store(&tick) +} + // duration gets the duration of the interval. Each call applies the jitter // if one was supplied. func (i *Interval) duration() time.Duration { @@ -163,13 +187,17 @@ func (i *Interval) run(timer clockwork.Timer) { // output channel is set. timer.Reset(i.duration()) ch = i.ch - case <-i.reset: + i.setLastTick(tick) + case d := <-i.reset: // stop and drain timer if !timer.Stop() { <-timer.Chan() } + if d == 0 { + d = i.duration() + } // re-set the timer - timer.Reset(i.duration()) + timer.Reset(d) // ensure we don't send any pending ticks ch = nil case <-i.fire: @@ -182,6 +210,7 @@ func (i *Interval) run(timer clockwork.Timer) { // simulate firing of the timer tick = time.Now() ch = i.ch + i.setLastTick(tick) case ch <- tick: // tick has been sent, set ch back to nil to prevent // double-send and wait for next timer firing diff --git a/lib/utils/interval/interval_test.go b/lib/utils/interval/interval_test.go index 281df018f9919..5fd7ee5f061a2 100644 --- a/lib/utils/interval/interval_test.go +++ b/lib/utils/interval/interval_test.go @@ -27,6 +27,54 @@ import ( "github.com/stretchr/testify/require" ) +// TestLastTick verifies that the LastTick method returns the last tick time as expected. Due to the +// flaky nature of time-based tests, this test runs many cases and passes if >50% of them succeed. 
+func TestLastTick(t *testing.T) { + const workers = 1_000 + const ticks = 12 + t.Parallel() + + var success, failure atomic.Uint64 + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + interval := New(Config{ + Duration: time.Millisecond * 333, + }) + + _, ok := interval.LastTick() + if ok { + failure.Add(1) + return + } + + for i := 0; i < ticks; i++ { + actual := <-interval.Next() + t, ok := interval.LastTick() + if !ok { + failure.Add(1) + return + } + + if actual != t { + failure.Add(1) + return + } + } + + success.Add(1) + }() + } + + wg.Wait() + + require.Greater(t, success.Load(), failure.Load()) +} + // TestIntervalReset verifies the basic behavior of the interval reset functionality. // Since time based tests tend to be flaky, this test passes if it has a >50% success // rate (i.e. >50% of resets seemed to have actually extended the timer successfully). @@ -83,6 +131,53 @@ func TestIntervalReset(t *testing.T) { require.Greater(t, success.Load(), failure.Load()) } +// TestIntervalResetTo verifies the basic behavior of the interval ResetTo method. +// Since time-based tests tend to be flaky, this test passes if it has a >50% success +// rate (i.e. >50% of ResetTo calls seemed to have changed the timer's behavior as expected). 
+func TestIntervalResetTo(t *testing.T) { + const workers = 1_000 + const ticks = 12 + const longDuration = time.Millisecond * 800 + const shortDuration = time.Millisecond * 200 + t.Parallel() + + var success, failure atomic.Uint64 + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + interval := New(Config{ + Duration: longDuration, + }) + defer interval.Stop() + + start := time.Now() + + for i := 0; i < ticks; i++ { + interval.ResetTo(shortDuration) + <-interval.Next() + } + + elapsed := time.Since(start) + // if the above work completed before the expected minimum time + // to complete all ticks as long ticks, we assume that ResetTo has + // successfully shortened the interval. + if elapsed < longDuration*time.Duration(ticks) { + success.Add(1) + } else { + failure.Add(1) + } + }() + } + + wg.Wait() + + require.Greater(t, success.Load(), failure.Load()) +} + func TestNewNoop(t *testing.T) { t.Parallel() i := NewNoop()