diff --git a/supervisor.go b/supervisor.go index 4e1ae9a..57b7fef 100644 --- a/supervisor.go +++ b/supervisor.go @@ -177,17 +177,16 @@ type timedSupervisor struct { errFn ErrorFunc t *time.Ticker - resetCh chan struct{} + commits uint32 running uint32 } // NewTimedSupervisor returns a supervisor that commits automatically. func NewTimedSupervisor(inner Supervisor, d time.Duration, errFn ErrorFunc) Supervisor { return &timedSupervisor{ - inner: inner, - d: d, - errFn: errFn, - resetCh: make(chan struct{}, 1), + inner: inner, + d: d, + errFn: errFn, } } @@ -208,17 +207,16 @@ func (s *timedSupervisor) Start() error { s.t = time.NewTicker(s.d) go func() { - for { - select { - case <-s.t.C: - err := s.inner.Commit(nil) - if err != nil { - s.errFn(err) - } - - case <-s.resetCh: - s.t.Stop() - s.t = time.NewTicker(s.d) + for range s.t.C { + // If there was a commit triggered "manually" by a Committer, skip a single timed commit. + if atomic.LoadUint32(&s.commits) > 0 { + atomic.StoreUint32(&s.commits, 0) + continue + } + + err := s.inner.Commit(nil) + if err != nil { + s.errFn(err) } } }() @@ -245,11 +243,13 @@ func (s *timedSupervisor) Commit(caller Processor) error { return ErrNotRunning } + // Increment the commit count + atomic.AddUint32(&s.commits, 1) + err := s.inner.Commit(caller) if err != nil { return err } - s.resetCh <- struct{}{} return nil } diff --git a/supervisor_test.go b/supervisor_test.go index 589b260..c643a46 100644 --- a/supervisor_test.go +++ b/supervisor_test.go @@ -381,22 +381,22 @@ func TestTimedSupervisor_Commit(t *testing.T) { inner.AssertCalled(t, "Commit", caller) } -func TestTimedSupervisor_CommitResetsTimer(t *testing.T) { +func TestTimedSupervisor_ManualCommitSkipsTimedCommit(t *testing.T) { caller := new(MockProcessor) inner := new(MockSupervisor) inner.On("Start").Return(nil) - inner.On("Commit", mock.Anything).Return(nil) + inner.On("Commit", caller).Return(nil) inner.On("Close").Return(nil) - supervisor := streams.NewTimedSupervisor(inner, 10*time.Millisecond, nil) + supervisor := streams.NewTimedSupervisor(inner, 5*time.Millisecond, nil) _ = supervisor.Start() defer supervisor.Close() - time.Sleep(5 * time.Millisecond) + time.Sleep(2 * time.Millisecond) _ = supervisor.Commit(caller) - time.Sleep(5 * time.Millisecond) + time.Sleep(4 * time.Millisecond) inner.AssertNumberOfCalls(t, "Commit", 1) }