diff --git a/rpc.go b/rpc.go index 7c5856bb..bbc6c219 100644 --- a/rpc.go +++ b/rpc.go @@ -13,7 +13,6 @@ import ( "io" "math" "strconv" - "sync" "time" log "github.com/sirupsen/logrus" @@ -242,28 +241,29 @@ func (c *client) SendBatch(ctx context.Context, batch []hrpc.Call) ( } // Send each group of RPCs to region client to be executed. - var ( - wg sync.WaitGroup - - mu sync.Mutex - fail bool - ) - wg.Add(len(rpcByClient)) + type clientAndRPCs struct { + client hrpc.RegionClient + rpcs []hrpc.Call + } + // keep track of the order requests are queued so that we can wait + // for their responses in the same order. + cAndRs := make([]clientAndRPCs, 0, len(rpcByClient)) for client, rpcs := range rpcByClient { - go func(client hrpc.RegionClient, rpcs []hrpc.Call) { - defer wg.Done() - client.QueueBatch(ctx, rpcs) - ctx, sp := observability.StartSpan(ctx, "waitForResult") - defer sp.End() - ok := c.waitForCompletion(ctx, client, rpcs, res, rpcToRes) + client.QueueBatch(ctx, rpcs) + cAndRs = append(cAndRs, clientAndRPCs{client, rpcs}) + } + + var fail bool + func() { // func used to scope the span + ctx, sp := observability.StartSpan(ctx, "waitForResult") + defer sp.End() + for _, cAndR := range cAndRs { + ok := c.waitForCompletion(ctx, cAndR.client, cAndR.rpcs, res, rpcToRes) if !ok { - mu.Lock() fail = true - mu.Unlock() } - }(client, rpcs) - } - wg.Wait() + } + }() allOK = !fail return res, allOK