diff --git a/rpc.go b/rpc.go index 7c5856bb..0fa13614 100644 --- a/rpc.go +++ b/rpc.go @@ -14,6 +14,7 @@ import ( "math" "strconv" "sync" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -243,28 +244,35 @@ func (c *client) SendBatch(ctx context.Context, batch []hrpc.Call) ( // Send each group of RPCs to region client to be executed. var ( - wg sync.WaitGroup - - mu sync.Mutex - fail bool + wg sync.WaitGroup + fail uint32 // set atomically to 1 if one of the clients fail ) - wg.Add(len(rpcByClient)) + clientCount := len(rpcByClient) + wg.Add(clientCount) + + sendBlocking := func(client hrpc.RegionClient, rpcs []hrpc.Call) { + defer wg.Done() + client.QueueBatch(ctx, rpcs) + ctx, sp := observability.StartSpan(ctx, "waitForResult") + defer sp.End() + ok := c.waitForCompletion(ctx, client, rpcs, res, rpcToRes) + if !ok { + atomic.StoreUint32(&fail, 1) + } + } + + i := 0 for client, rpcs := range rpcByClient { - go func(client hrpc.RegionClient, rpcs []hrpc.Call) { - defer wg.Done() - client.QueueBatch(ctx, rpcs) - ctx, sp := observability.StartSpan(ctx, "waitForResult") - defer sp.End() - ok := c.waitForCompletion(ctx, client, rpcs, res, rpcToRes) - if !ok { - mu.Lock() - fail = true - mu.Unlock() - } - }(client, rpcs) + if i++; i < clientCount { + go sendBlocking(client, rpcs) + } else { + // Small optimization: don't launch a goroutine for the + // last client + sendBlocking(client, rpcs) + } } wg.Wait() - allOK = !fail + allOK = fail == 0 return res, allOK }