diff --git a/gomock/controller.go b/gomock/controller.go index 6846d0d..9d17a2f 100644 --- a/gomock/controller.go +++ b/gomock/controller.go @@ -246,6 +246,8 @@ func (ctrl *Controller) Finish() { // Satisfied returns whether all expected calls bound to this Controller have been satisfied. // Calling Finish is then guaranteed to not fail due to missing calls. func (ctrl *Controller) Satisfied() bool { + ctrl.mu.Lock() + defer ctrl.mu.Unlock() return ctrl.expectedCalls.Satisfied() } diff --git a/sample/concurrent/README.md b/sample/concurrent/README.md new file mode 100644 index 0000000..89e3f5d --- /dev/null +++ b/sample/concurrent/README.md @@ -0,0 +1,9 @@ +# Concurrent + +This directory contains an example of executing mock calls concurrently. + +To run the test, + +```bash +go test -race go.uber.org/mock/sample/concurrent +``` diff --git a/sample/concurrent/concurrent_test.go b/sample/concurrent/concurrent_test.go index 3dc8822..b660bbc 100644 --- a/sample/concurrent/concurrent_test.go +++ b/sample/concurrent/concurrent_test.go @@ -2,7 +2,9 @@ package concurrent import ( "context" + "fmt" "testing" + "time" "go.uber.org/mock/gomock" mock "go.uber.org/mock/sample/concurrent/mock" @@ -22,6 +24,26 @@ func call(ctx context.Context, m Math) (int, error) { } } +func waitForMocks(ctx context.Context, ctrl *gomock.Controller) error { + ticker := time.NewTicker(1 * time.Millisecond) + defer ticker.Stop() + + timeout := time.After(3 * time.Millisecond) + + for { + select { + case <-ticker.C: + if ctrl.Satisfied() { + return nil + } + case <-timeout: + return fmt.Errorf("timeout waiting for mocks to be satisfied") + case <-ctx.Done(): + return fmt.Errorf("context cancelled") + } + } +} + // TestConcurrentFails is expected to fail (and is disabled). It // demonstrates how to use gomock.WithContext to interrupt the test // from a different goroutine. @@ -42,3 +64,26 @@ func TestConcurrentWorks(t *testing.T) { t.Error("call failed:", err) } } + +func TestCancelWhenMocksSatisfied(t *testing.T) { + ctrl, ctx := gomock.WithContext(context.Background(), t) + m := mock.NewMockMath(ctrl) + m.EXPECT().Sum(1, 2).Return(3).MinTimes(1) + + // This goroutine calls the mock and then waits for the context to be done. + go func() { + for { + m.Sum(1, 2) + select { + case <-ctx.Done(): + return + } + } + }() + + // waitForMocks spawns another goroutine which blocks until ctrl.Satisfied() is true. + if err := waitForMocks(ctx, ctrl); err != nil { + t.Error("call failed:", err) + } + ctrl.Finish() +}