diff --git a/balancer.go b/balancer.go
index 474ec61..733c4da 100644
--- a/balancer.go
+++ b/balancer.go
@@ -47,19 +47,23 @@ func newBalancer(
 	picker func(prev picker.Picker, allConns conn.Conns) picker.Picker,
 	checker health.Checker,
 	pool connPool,
+	roundTripperMaxLifetime time.Duration,
 ) *balancer {
 	ctx, cancel := context.WithCancel(ctx)
 	balancer := &balancer{
-		ctx:             ctx,
-		cancel:          cancel,
-		pool:            pool,
-		newPicker:       picker,
-		healthChecker:   checker,
-		resolverUpdates: make(chan struct{}, 1),
-		closed:          make(chan struct{}),
-		connInfo:        map[conn.Conn]connInfo{},
-		clock:           internal.NewRealClock(),
+		ctx:                     ctx,
+		cancel:                  cancel,
+		pool:                    pool,
+		newPicker:               picker,
+		healthChecker:           checker,
+		roundTripperMaxLifetime: roundTripperMaxLifetime,
+		resolverUpdates:         make(chan struct{}, 1),
+		recycleConns:            make(chan struct{}, 1),
+		closed:                  make(chan struct{}),
+		connInfo:                map[conn.Conn]connInfo{},
+		clock:                   internal.NewRealClock(),
 	}
+	balancer.connManager.updateFunc = balancer.updateConns
 	return balancer
 }
 
@@ -81,6 +85,8 @@ type balancer struct {
 	healthChecker health.Checker
 	// +checklocksignore: mu is not required, but happens to always be held.
 	connManager connManager
+	roundTripperMaxLifetime time.Duration // +checklocksignore: mu is not required, but happens to always be held.
+
 	// NB: only set from tests
 	updateHook func([]resolver.Address, []conn.Conn)
 
@@ -93,6 +99,7 @@ type balancer struct {
 	latestAddrs     atomic.Pointer[[]resolver.Address]
 	latestErr       atomic.Pointer[error]
 	resolverUpdates chan struct{}
+	recycleConns    chan struct{}
 	clock           internal.Clock
 
 	mu sync.Mutex
@@ -106,6 +113,8 @@ type balancer struct {
 	connInfo map[conn.Conn]connInfo
 	// +checklocks:mu
 	reresolveLastCall time.Time
+	// +checklocks:mu
+	connsToRecycle []conn.Conn
 }
 
 func (b *balancer) UpdateHealthState(connection conn.Conn, state health.State) {
@@ -195,7 +204,7 @@ func (b *balancer) receiveAddrs(ctx context.Context) {
 		for key, info := range b.connInfo {
 			delete(b.connInfo, key)
 			closer := info.closeChecker
-			info.cancelWarm()
+			info.cancel()
 			if closer != nil {
 				grp.Go(doClose(closer))
 			}
@@ -213,6 +222,15 @@ func (b *balancer) receiveAddrs(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			return
+		case <-b.recycleConns:
+			b.mu.Lock()
+			connsToRecycle := b.connsToRecycle
+			b.connsToRecycle = nil
+			b.mu.Unlock()
+			if len(connsToRecycle) > 0 {
+				b.connManager.recycleConns(connsToRecycle)
+			}
+
 		case <-b.resolverUpdates:
 			addrs := b.latestAddrs.Load()
 			if addrs == nil {
@@ -242,7 +260,7 @@ func (b *balancer) receiveAddrs(ctx context.Context) {
 			if len(*addrs) > 0 {
 				addrsClone := make([]resolver.Address, len(*addrs))
 				copy(addrsClone, *addrs)
-				b.connManager.reconcileAddresses(addrsClone, b.updateConns)
+				b.connManager.reconcileAddresses(addrsClone)
 			}
 		}
 	}
@@ -282,7 +300,7 @@ func (b *balancer) updateConns(newAddrs []resolver.Address, removeConns []conn.C
 			// and omit it from newConns
 			info := b.connInfo[existing]
 			delete(b.connInfo, existing)
-			info.cancelWarm()
+			info.cancel()
 			if info.closeChecker != nil {
 				_ = info.closeChecker.Close()
 			}
@@ -291,8 +309,16 @@ func (b *balancer) updateConns(newAddrs []resolver.Address, removeConns []conn.C
 		newConns = append(newConns, existing)
 	}
 	newConns = append(newConns, addConns...)
-	for i := range addConns {
-		connection := addConns[i]
+	b.initConnInfoLocked(addConns)
+	b.conns = newConns
+	b.newPickerLocked()
+	return addConns
+}
+
+// +checklocks:b.mu
+func (b *balancer) initConnInfoLocked(conns []conn.Conn) {
+	for i := range conns {
+		connection := conns[i]
 		connCtx, connCancel := context.WithCancel(b.ctx)
 		healthChecker := b.healthChecker.New(connCtx, connection, b)
 		go func() {
@@ -301,11 +327,18 @@ func (b *balancer) updateConns(newAddrs []resolver.Address, removeConns []conn.C
 				b.warmedUp(connection)
 			}
 		}()
-		b.connInfo[connection] = connInfo{closeChecker: healthChecker, cancelWarm: connCancel}
+		cancel := connCancel
+		if b.roundTripperMaxLifetime != 0 {
+			timer := time.AfterFunc(b.roundTripperMaxLifetime, func() {
+				b.recycle(connection)
+			})
+			cancel = func() {
+				connCancel()
+				timer.Stop()
+			}
+		}
+		b.connInfo[connection] = connInfo{closeChecker: healthChecker, cancel: cancel}
 	}
-	b.conns = newConns
-	b.newPickerLocked()
-	return addConns
 }
 
 // +checklocks:b.mu
@@ -392,19 +425,36 @@ func (b *balancer) setErrorPickerLocked(err error) {
 	b.pool.UpdatePicker(picker.ErrorPicker(err), false)
 }
 
+func (b *balancer) recycle(c conn.Conn) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	b.connsToRecycle = append(b.connsToRecycle, c)
+	// Notify goroutine that there is a connection to recycle.
+	select {
+	case b.recycleConns <- struct{}{}:
+	default:
+	}
+}
+
 type connInfo struct {
-	state      health.State
-	warm       bool
-	cancelWarm context.CancelFunc
+	state health.State
+	warm  bool
+
+	// Cancels any in-progress warm-up and also cancels any timer
+	// for recycling the connection. Invoked when the connection
+	// is closed.
+	cancel context.CancelFunc
 	closeChecker io.Closer
 }
 
 type connManager struct {
-	// only modified by a single goroutine, so mu is not necessary
+	// only used by a single goroutine, so no mutex necessary
 	connsByAddr map[string][]conn.Conn
+
+	updateFunc func([]resolver.Address, []conn.Conn) []conn.Conn
 }
 
-func (c *connManager) reconcileAddresses(addrs []resolver.Address, updateFunc func([]resolver.Address, []conn.Conn) []conn.Conn) {
+func (c *connManager) reconcileAddresses(addrs []resolver.Address) {
 	// TODO: future extension: make connection establishing strategy configurable
 	// (which would allow more sophisticated connection strategies in the face
 	// of, for example, layer-4 load balancers)
@@ -446,13 +496,56 @@ func (c *connManager) reconcileAddresses(addrs []resolver.Address, updateFunc fu
 		newAddrs = append(newAddrs, want...)
 	}
+	c.connsByAddr = remaining
+	c.doUpdate(newAddrs, toRemove)
+}
+
+func (c *connManager) doUpdate(newAddrs []resolver.Address, toRemove []conn.Conn) {
 	// we make a single call to update connections in batch to create a single
 	// new picker (avoids potential picker churn from making one change at a time)
-	newConns := updateFunc(newAddrs, toRemove)
-	// add newConns to remaining to compute new set of connections
-	for _, c := range newConns {
-		hostPort := c.Address().HostPort
-		remaining[hostPort] = append(remaining[hostPort], c)
+	newConns := c.updateFunc(newAddrs, toRemove)
+	// add newConns to set of connections
+	for _, cn := range newConns {
+		hostPort := cn.Address().HostPort
+		c.connsByAddr[hostPort] = append(c.connsByAddr[hostPort], cn)
 	}
-	c.connsByAddr = remaining
+}
+
+func (c *connManager) recycleConns(connsToRecycle []conn.Conn) {
+	var needToCompact bool
+	for i, cn := range connsToRecycle {
+		addr := cn.Address().HostPort
+		existing := c.connsByAddr[addr]
+		var found bool
+		for i, existingConn := range existing {
+			if existingConn == cn {
+				found = true
+				// remove cn from the slice
+				copy(existing[i:], existing[i+1:])
+				c.connsByAddr[addr] = existing[:len(existing)-1]
+				break
+			}
+		}
+		if !found {
+			// this connection has already been closed/removed
+			connsToRecycle[i] = nil
+			needToCompact = true
+		}
+	}
+	if needToCompact {
+		i := 0
+		for _, cn := range connsToRecycle {
+			if cn != nil {
+				connsToRecycle[i] = cn
+				i++
+			}
+		}
+		connsToRecycle = connsToRecycle[:i]
+	}
+	newAddrs := make([]resolver.Address, len(connsToRecycle))
+	for i := range connsToRecycle {
+		newAddrs[i] = connsToRecycle[i].Address()
+	}
+
+	c.doUpdate(newAddrs, connsToRecycle)
 }
 
diff --git a/balancer_test.go b/balancer_test.go
index 393eef1..ac36654 100644
--- a/balancer_test.go
+++ b/balancer_test.go
@@ -32,7 +32,7 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-func TestConnManager(t *testing.T) {
+func TestConnManager_ReconcileAddresses(t *testing.T) {
 	t.Parallel()
 	type updateReq struct {
 		newAddrs []resolver.Address
@@ -72,7 +72,7 @@ func TestConnManager(t *testing.T) {
 			return updateReq{}
 		}
 	}
-	var connMgr connManager
+	connMgr := connManager{updateFunc: testUpdate}
 	addrs := []resolver.Address{
 		{HostPort: "1.2.3.1"},
 		{HostPort: "1.2.3.2"},
@@ -81,7 +81,7 @@ func TestConnManager(t *testing.T) {
 		{HostPort: "1.2.3.5"},
 		{HostPort: "1.2.3.6"},
 	}
-	connMgr.reconcileAddresses(addrs, testUpdate)
+	connMgr.reconcileAddresses(addrs)
 	latestUpdate := getLatestUpdate()
 	require.Equal(t, addrs, latestUpdate.newAddrs)
 	require.Empty(t, latestUpdate.removeConns)
@@ -107,7 +107,7 @@ func TestConnManager(t *testing.T) {
 		{HostPort: "1.2.3.3"},
 		{HostPort: "1.2.3.3"},
 	}
-	connMgr.reconcileAddresses(addrs, testUpdate)
+	connMgr.reconcileAddresses(addrs)
 	latestUpdate = getLatestUpdate()
 	// 10 entries needed, and we start with 3. So we need
 	// 2x more of each, but 3x of the first
@@ -149,7 +149,7 @@ func TestConnManager(t *testing.T) {
 		{HostPort: "1.2.3.3", Attributes: attrs3a},
 		{HostPort: "1.2.3.3", Attributes: attrs3b},
 	}
-	connMgr.reconcileAddresses(addrs, testUpdate)
+	connMgr.reconcileAddresses(addrs)
 	latestUpdate = getLatestUpdate()
 	require.Empty(t, latestUpdate.newAddrs)
 	require.Equal(t, []conn.Conn{conn1i8, conn1i9, conn2i11, conn3i13}, latestUpdate.removeConns)
@@ -184,7 +184,7 @@ func TestConnManager(t *testing.T) {
 		{HostPort: "1.2.3.6"},
 		{HostPort: "1.2.3.8"},
 	}
-	connMgr.reconcileAddresses(addrs, testUpdate)
+	connMgr.reconcileAddresses(addrs)
 	// Wanted to create 1.2.3.4, 1.2.3.6, and 1.2.3.8, but only first two created.
 	latestUpdate = getLatestUpdate()
 	require.Equal(t, addrs[1:], latestUpdate.newAddrs)
@@ -197,16 +197,21 @@ func TestConnManager(t *testing.T) {
 		{HostPort: "1.2.3.6"},
 		{HostPort: "1.2.3.8"},
 	}
-	connMgr.reconcileAddresses(addrs, testUpdate)
+	connMgr.reconcileAddresses(addrs)
 	latestUpdate = getLatestUpdate()
 	require.Equal(t, addrs[3:], latestUpdate.newAddrs)
 	require.Empty(t, latestUpdate.removeConns)
 }
 
+func TestConnManager_RecycleConns(t *testing.T) {
+	t.Parallel()
+	// TODO
+}
+
 func TestBalancer_BasicConnManagement(t *testing.T) {
 	t.Parallel()
 	pool := balancertesting.NewFakeConnPool()
-	balancer := newBalancer(context.Background(), balancertesting.NewFakePicker, health.NopChecker, pool)
+	balancer := newBalancer(context.Background(), balancertesting.NewFakePicker, health.NopChecker, pool, 0)
 	balancer.updateHook = balancertesting.DeterministicReconciler
 	balancer.start()
 	// Initial resolve
@@ -285,7 +290,7 @@ func TestBalancer_HealthChecking(t *testing.T) {
 			return ctx.Err()
 		}
 	}
-	balancer := newBalancer(context.Background(), balancertesting.NewFakePicker, checker, pool)
+	balancer := newBalancer(context.Background(), balancertesting.NewFakePicker, checker, pool, 0)
 	balancer.updateHook = balancertesting.DeterministicReconciler
 	balancer.start()
 
@@ -384,13 +389,13 @@ func TestBalancer_HealthChecking(t *testing.T) {
 	require.Empty(t, checkers)
 }
 
-func TestDefaultBalancer_Reresolve(t *testing.T) {
+func TestBalancer_Reresolve(t *testing.T) {
 	t.Parallel()
 
 	checker := balancertesting.NewFakeHealthChecker()
 	clock := clocktest.NewFakeClock()
 	pool := balancertesting.NewFakeConnPool()
-	balancer := newBalancer(context.Background(), balancertesting.NewFakePicker, checker, pool)
+	balancer := newBalancer(context.Background(), balancertesting.NewFakePicker, checker, pool, 0)
 	balancer.updateHook = balancertesting.DeterministicReconciler
 	balancer.clock = clock
 	balancer.start()
@@ -425,6 +430,11 @@ func TestDefaultBalancer_Reresolve(t *testing.T) {
 	require.Empty(t, checkers)
 }
 
+func TestBalancer_RoundTripperMaxLifetime(t *testing.T) {
+	t.Parallel()
+	// TODO
+}
+
 func awaitPickerUpdate(t *testing.T, pool *balancertesting.FakeConnPool, warm bool, addrs []resolver.Address, indexes []int) {
 	t.Helper()
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
diff --git a/client.go b/client.go
index 0751f8b..752b303 100644
--- a/client.go
+++ b/client.go
@@ -133,14 +133,35 @@ func WithRootContext(ctx context.Context) ClientOption {
 // If zero or no WithIdleTransportTimeout option is used, a default of
 // 15 minutes will be used.
 //
-// To prevent some transports from being closed due to being idle, use
-// WithKeepWarmTargets.
+// To prevent transports from being closed due to being idle, set an
+// arbitrarily large timeout (e.g. math.MaxInt64) or use WithAllowBackendTarget.
 func WithIdleTransportTimeout(duration time.Duration) ClientOption {
 	return clientOptionFunc(func(opts *clientOptions) {
 		opts.idleTransportTimeout = duration
 	})
 }
 
+// WithRoundTripperMaxLifetime configures a limit for how long a single
+// round tripper (or "leaf" transport) will be used. If no option is given,
+// round trippers are retained indefinitely, until their parent transport
+// is closed (which can happen if the transport is idle; see
+// WithIdleTransportTimeout). When a round tripper reaches its maximum
+// lifetime, a new one is first created to replace it. Any in-progress
+// operations are allowed to finish, but new operations use the replacement.
+//
+// This option is mainly useful when the target host is a layer-4 proxy.
+// In this situation, it is possible for multiple round trippers to all
+// get connected to the same backend host, resulting in poor load
+// distribution. With a lifetime limit, a single round tripper will get
+// "recycled", and its replacement is likely to be connected to a
+// different backend host. So when a transport gets into a scenario where
+// it has poor backend diversity, the lifetime limit allows it to self-heal.
+func WithRoundTripperMaxLifetime(duration time.Duration) ClientOption {
+	return clientOptionFunc(func(opts *clientOptions) {
+		opts.roundTripperMaxLifetime = duration
+	})
+}
+
 // WithTransport returns an option that uses a custom transport to create
 // [http.RoundTripper] instances for the given URL scheme. This allows
 // one to override the default implementations for "http", "https", and
@@ -354,6 +375,7 @@ func (f clientOptionFunc) applyToClient(opts *clientOptions) {
 type clientOptions struct {
 	rootCtx                 context.Context //nolint:containedctx
 	idleTransportTimeout    time.Duration
+	roundTripperMaxLifetime time.Duration
 	schemes                 map[string]Transport
 	redirectFunc            func(req *http.Request, via []*http.Request) error
 	allowedTarget           *target
diff --git a/transport.go b/transport.go
index 10d6046..b2ca9f6 100644
--- a/transport.go
+++ b/transport.go
@@ -343,6 +343,7 @@ func (m *mainTransport) getOrCreatePool(dest target) (*transportPool, error) {
 		m.clientOptions.resolver,
 		m.clientOptions.newPicker,
 		m.clientOptions.healthChecker,
+		m.clientOptions.roundTripperMaxLifetime,
 		dest,
 		applyTimeout,
 		schemeConf,
@@ -488,6 +489,7 @@ func newTransportPool(
 	res resolver.Resolver,
 	newPicker func(prev picker.Picker, allConns conn.Conns) picker.Picker,
 	checker health.Checker,
+	roundTripperMaxLifetime time.Duration,
 	dest target,
 	applyTimeout func(ctx context.Context) (context.Context, context.CancelFunc),
 	transport Transport,
@@ -507,7 +509,7 @@ func newTransportPool(
 		onClose:    onClose,
 	}
 	pool.warmCond = sync.NewCond(&pool.mu)
-	pool.balancer = newBalancer(ctx, newPicker, checker, pool)
+	pool.balancer = newBalancer(ctx, newPicker, checker, pool, roundTripperMaxLifetime)
 	pool.resolver = res.New(ctx, dest.scheme, dest.hostPort, pool.balancer, reresolve)
 	pool.balancer.start()
 	return pool
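
Usage sketch (not part of the patch): assuming the package in this diff is the public httplb module at github.com/bufbuild/httplb, and that NewClient, Client.Do, and Client.Close exist as suggested by client.go above, a caller whose target sits behind a layer-4 proxy might combine the existing idle timeout with the new lifetime limit as shown below. The import path, durations, and target URL are illustrative assumptions, not part of this change.

package main

import (
	"log"
	"net/http"
	"time"

	"github.com/bufbuild/httplb" // assumed import path for the package being patched
)

func main() {
	client := httplb.NewClient(
		// Existing option: close whole transports that sit idle for 5 minutes.
		httplb.WithIdleTransportTimeout(5*time.Minute),
		// New option from this diff: recycle each leaf round tripper after
		// 10 minutes so connections funneled through a layer-4 proxy get
		// re-spread across backends over time.
		httplb.WithRoundTripperMaxLifetime(10*time.Minute),
	)
	defer client.Close() // assumed to release transports and round trippers

	req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
	if err != nil {
		log.Fatal(err)
	}
	resp, err := client.Do(req) // assumed http.Client-style Do method
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	log.Println(resp.Status)
}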