From afa3d424e82fe5da183b5b60fb620a217ebfdcdf Mon Sep 17 00:00:00 2001 From: EstebanOlmedo Date: Mon, 11 Sep 2023 11:25:04 -0600 Subject: [PATCH] Add CallWrapper interface for type-safe calls Add a new exported interface `CallWrapper` which allow users to use `InOrder` and `After` with generated type-safe mock types. --- gomock/call.go | 13 +++++++++---- mockgen/internal/tests/typed_after_in_order/mock.go | 8 ++++---- mockgen/mockgen.go | 4 ++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/gomock/call.go b/gomock/call.go index 6b944f8..6eb2511 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -44,6 +44,11 @@ type Call struct { actions []func([]any) []any } +// CallWrapper is an interface for retrieving a *Call. +type CallWrapper interface { + GetCall() *Call +} + // newCall creates a *Call. It requires the method type in order to support // unexported methods. func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, args ...any) *Call { @@ -79,8 +84,8 @@ func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} } -// GetCall returns the current `*Call` instance, this is needed to fulfill the -// interface that `InOrder` and `After` receive as parameter. +// GetCall returns the current `*Call` instance, this is needed to implement +// the CallWrapper interface. func (c *Call) GetCall() *Call { return c } @@ -294,7 +299,7 @@ func (c *Call) isPreReq(other *Call) bool { } // After declares that the call may only match after preReq has been exhausted. -func (c *Call) After(prq interface{GetCall() *Call}) *Call { +func (c *Call) After(prq CallWrapper) *Call { preReq := prq.GetCall() c.t.Helper() @@ -442,7 +447,7 @@ func (c *Call) call() []func([]any) []any { } // InOrder declares that the given calls should occur in order. -func InOrder(calls ...interface{GetCall() *Call}) { +func InOrder(calls ...CallWrapper) { for i := 1; i < len(calls); i++ { calls[i].GetCall().After(calls[i-1]) } diff --git a/mockgen/internal/tests/typed_after_in_order/mock.go b/mockgen/internal/tests/typed_after_in_order/mock.go index 16838f9..8df9905 100644 --- a/mockgen/internal/tests/typed_after_in_order/mock.go +++ b/mockgen/internal/tests/typed_after_in_order/mock.go @@ -75,13 +75,13 @@ func (c *AnimalFeedCall) DoAndReturn(f func(string) error) *AnimalFeedCall { return c } -// Call rewrite *gomock.Call.GetCall +// GetCall is needed to implement gomock.CallWrapper func (c *AnimalFeedCall) GetCall() *gomock.Call { return c.Call } // After rewrite *gomock.Call.After -func (c *AnimalFeedCall) After(prq interface{ GetCall() *gomock.Call }) *gomock.Call { +func (c *AnimalFeedCall) After(prq gomock.CallWrapper) *gomock.Call { return c.Call.After(prq) } @@ -123,12 +123,12 @@ func (c *AnimalGetNoiseCall) DoAndReturn(f func() string) *AnimalGetNoiseCall { return c } -// Call rewrite *gomock.Call.GetCall +// GetCall is needed to implement gomock.CallWrapper func (c *AnimalGetNoiseCall) GetCall() *gomock.Call { return c.Call } // After rewrite *gomock.Call.After -func (c *AnimalGetNoiseCall) After(prq interface{ GetCall() *gomock.Call }) *gomock.Call { +func (c *AnimalGetNoiseCall) After(prq gomock.CallWrapper) *gomock.Call { return c.Call.After(prq) } diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index afb2e30..9f30ecb 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -700,7 +700,7 @@ func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model g.out() g.p("}") - g.p("// Call rewrite *gomock.Call.GetCall") + g.p("// GetCall is needed to implement gomock.CallWrapper") g.p("func (%s *%sCall%s) GetCall() *gomock.Call {", idRecv, recvStructName, shortTp) g.in() g.p("return %s.Call", idRecv) @@ -708,7 +708,7 @@ func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model g.p("}") g.p("// After rewrite *gomock.Call.After") - g.p("func (%s *%sCall%s) After(prq interface{ GetCall() *gomock.Call }) *gomock.Call {", idRecv, recvStructName, shortTp) + g.p("func (%s *%sCall%s) After(prq gomock.CallWrapper) *gomock.Call {", idRecv, recvStructName, shortTp) g.in() g.p("return %s.Call.After(prq)", idRecv) g.out()