diff --git a/shutter.go b/shutter.go index 49c6960..fcbedb4 100644 --- a/shutter.go +++ b/shutter.go @@ -1,6 +1,9 @@ package shutter -import "sync" +import ( + "errors" + "sync" +) type Shutter struct { lock sync.Mutex // shutdown lock @@ -26,6 +29,25 @@ func NewWithCallback(f func(error)) *Shutter { return s } +var ErrShutterWasAlreadyDown = errors.New("saferun was called on an already-shutdown shutter") + +// SafeRun allows you to run a function only if the shutter is not down yet, +// with the assurance that the it will not run its callback functions +// during the execution of your function. +// +// This is useful to prevent race conditions, where the func given to "SafeRun" +// should increase a counter and the func given to OnShutdown should decrease it. +// +// WARNING: never call Shutdown from within your SafeRun function, it will deadlock. +func (s *Shutter) SafeRun(fn func() error) (err error) { + s.lock.Lock() + defer s.lock.Unlock() + if s.IsDown() { + return ErrShutterWasAlreadyDown + } + return fn() +} + func (s *Shutter) Shutdown(err error) { var execute = false s.once.Do(func() { diff --git a/shutter_test.go b/shutter_test.go index f35b193..00e2db5 100644 --- a/shutter_test.go +++ b/shutter_test.go @@ -3,6 +3,7 @@ package shutter import ( "errors" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -33,3 +34,69 @@ func TestMultiCallbacks(t *testing.T) { s.Shutdown(nil) assert.Equal(t, 2, a) } + +func TestSafeRunAlreadyShutdown(t *testing.T) { + s := New() + a := 0 + s.OnShutdown(func(_ error) { + a-- + }) + s.Shutdown(nil) + err := s.SafeRun(func() error { + a++ + return nil + }) + + assert.Equal(t, -1, a) + assert.Equal(t, ErrShutterWasAlreadyDown, err) +} + +func TestSafeRunNotShutdown(t *testing.T) { + s := New() + a := 0 + s.OnShutdown(func(_ error) { + a-- + }) + err := s.SafeRun(func() error { + a++ + return nil + }) + assert.NoError(t, err) + s.Shutdown(nil) + assert.Equal(t, 0, a) +} + +func TestShutdownDuringSafeRun(t *testing.T) { + s := New() + + a := 0 + s.OnShutdown(func(_ error) { + a-- + }) + + var err error + inSafeRunCh := make(chan interface{}) + shutdownCalled := make(chan interface{}) + + go func() { + err = s.SafeRun(func() error { + close(inSafeRunCh) + select { + case <-shutdownCalled: + t.Errorf("Shutdown was called and completed while in SafeRun") + case <-time.After(50 * time.Millisecond): + return nil + } + return nil + }) + }() + + <-inSafeRunCh + go func() { + s.Shutdown(nil) + close(shutdownCalled) + }() + assert.NoError(t, err) + <-shutdownCalled + assert.Equal(t, -1, a) +}