diff --git a/lib/inventory/controller.go b/lib/inventory/controller.go
index ead22598011e4..e00ab8ad9783d 100644
--- a/lib/inventory/controller.go
+++ b/lib/inventory/controller.go
@@ -61,6 +61,7 @@ const (
 	appKeepAliveOk  testEvent = "app-keep-alive-ok"
 	appKeepAliveErr testEvent = "app-keep-alive-err"
+	appKeepAliveDel testEvent = "app-keep-alive-del"
 
 	appUpsertOk  testEvent = "app-upsert-ok"
 	appUpsertErr testEvent = "app-upsert-err"
@@ -75,6 +76,8 @@ const (
 	handlerStart = "handler-start"
 	handlerClose = "handler-close"
+
+	keepAliveTick = "keep-alive-tick"
 )
 
 // intervalKey is used to uniquely identify the subintervals registered with the interval.MultiInterval
@@ -622,6 +625,10 @@ func (c *Controller) handleAgentMetadata(handle *upstreamHandle, m proto.Upstrea
 }
 
 func (c *Controller) keepAliveServer(handle *upstreamHandle, now time.Time) error {
+	// always fire off 'tick' event after keepalive processing to ensure
+	// that waiting for N ticks maps intuitively to waiting for N keepalive
+	// processing steps.
+	defer c.testEvent(keepAliveTick)
 	if err := c.keepAliveSSHServer(handle, now); err != nil {
 		return trace.Wrap(err)
 	}
@@ -647,14 +654,15 @@ func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time) e
 				srv.keepAliveErrs++
 				handle.appServers[name] = srv
-				shouldClose := srv.keepAliveErrs > c.maxKeepAliveErrs
-
-				log.Warnf("Failed to keep alive app server %q: %v (count=%d, closing=%v).", handle.Hello().ServerID, err, srv.keepAliveErrs, shouldClose)
+				shouldRemove := srv.keepAliveErrs > c.maxKeepAliveErrs
+				log.Warnf("Failed to keep alive app server %q: %v (count=%d, removing=%v).", handle.Hello().ServerID, err, srv.keepAliveErrs, shouldRemove)
 
-				if shouldClose {
-					return trace.Errorf("failed to keep alive app server: %v", err)
+				if shouldRemove {
+					c.testEvent(appKeepAliveDel)
+					delete(handle.appServers, name)
 				}
 			} else {
+				srv.keepAliveErrs = 0
 				c.testEvent(appKeepAliveOk)
 			}
 		} else if srv.retryUpsert {
@@ -697,6 +705,7 @@ func (c *Controller) keepAliveSSHServer(handle *upstreamHandle, now time.Time) e
 				return trace.Errorf("failed to keep alive ssh server: %v", err)
 			}
 		} else {
+			handle.sshServer.keepAliveErrs = 0
 			c.testEvent(sshKeepAliveOk)
 		}
 	} else if handle.sshServer.retryUpsert {
diff --git a/lib/inventory/controller_test.go b/lib/inventory/controller_test.go
index 405841f363b53..e76ae4b7008b4 100644
--- a/lib/inventory/controller_test.go
+++ b/lib/inventory/controller_test.go
@@ -279,6 +279,7 @@ func TestSSHServerBasics(t *testing.T) {
 // an app service.
 func TestAppServerBasics(t *testing.T) {
 	const serverID = "test-server"
+	const appCount = 3
 
 	t.Parallel()
 
@@ -311,7 +312,7 @@ func TestAppServerBasics(t *testing.T) {
 	require.True(t, ok)
 
 	// send a fake app server heartbeat
-	for i := 0; i < 3; i++ {
+	for i := 0; i < appCount; i++ {
 		err := downstream.Send(ctx, proto.InventoryHeartbeat{
 			AppServer: &types.AppServerV3{
 				Metadata: types.Metadata{
@@ -353,7 +354,7 @@ func TestAppServerBasics(t *testing.T) {
 		deny(appUpsertErr, handlerClose),
 	)
 
-	for i := 0; i < 3; i++ {
+	for i := 0; i < appCount; i++ {
 		err := downstream.Send(ctx, proto.InventoryHeartbeat{
 			AppServer: &types.AppServerV3{
 				Metadata: types.Metadata{
@@ -402,6 +403,38 @@ func TestAppServerBasics(t *testing.T) {
 	_, err := handle.Ping(pingCtx, 1)
 	require.NoError(t, err)
 
+	// ensure that local app keepalive states have reset to healthy by waiting
+	// on a full cycle+ worth of keepalives without errors.
+	awaitEvents(t, events,
+		expect(keepAliveTick, keepAliveTick),
+		deny(appKeepAliveErr, handlerClose),
+	)
+
+	// set up to induce enough consecutive keepalive errors to cause removal
+	// of server-side keepalive state.
+	auth.mu.Lock()
+	auth.failKeepAlives = 3 * appCount
+	auth.mu.Unlock()
+
+	// expect that all app keepalives fail, then the app is removed.
+	var expectedEvents []testEvent
+	for i := 0; i < appCount; i++ {
+		expectedEvents = append(expectedEvents, []testEvent{appKeepAliveErr, appKeepAliveErr, appKeepAliveErr, appKeepAliveDel}...)
+	}
+
+	// wait for failed keepalives to trigger removal
+	awaitEvents(t, events,
+		expect(expectedEvents...),
+		deny(handlerClose),
+	)
+
+	// verify that further keepalive ticks do not result in attempts to keepalive
+	// apps (successful or not).
+	awaitEvents(t, events,
+		expect(keepAliveTick, keepAliveTick, keepAliveTick),
+		deny(appKeepAliveOk, appKeepAliveErr, handlerClose),
+	)
+
 	// set up to induce enough consecutive errors to cause stream closure
 	auth.mu.Lock()
 	auth.failUpserts = 5
@@ -736,7 +769,7 @@ func awaitEvents(t *testing.T, ch <-chan testEvent, opts ...eventOption) {
 		opt(&options)
 	}
 
-	timeout := time.After(time.Second * 5)
+	timeout := time.After(time.Second * 30)
 	for {
 		if len(options.expect) == 0 {
 			return