Skip to content

Commit

Permalink
Fix connected resource counts after keepalive errors (#47931)
Browse files Browse the repository at this point in the history
* Fix connected resource counts after keepalive errors

* Log server_id when cleaning up resources
  • Loading branch information
espadolini authored Oct 25, 2024
1 parent cd93b25 commit 91bb17a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 23 deletions.
4 changes: 2 additions & 2 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
log.Warnf("missing connected resources gauge for keep alive %s (this is a bug)", s)
}
}),
inventory.WithOnDisconnect(func(s string) {
inventory.WithOnDisconnect(func(s string, c int) {
if g, ok := connectedResourceGauges[s]; ok {
g.Dec()
g.Sub(float64(c))
} else {
log.Warnf("missing connected resources gauge for keep alive %s (this is a bug)", s)
}
Expand Down
44 changes: 24 additions & 20 deletions lib/inventory/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ type controllerOptions struct {
maxKeepAliveErrs int
authID string
onConnectFunc func(string)
onDisconnectFunc func(string)
onDisconnectFunc func(string, int)
}

func (options *controllerOptions) SetDefaults() {
Expand All @@ -153,11 +153,11 @@ func (options *controllerOptions) SetDefaults() {
}

if options.onConnectFunc == nil {
options.onConnectFunc = func(s string) {}
options.onConnectFunc = func(string) {}
}

if options.onDisconnectFunc == nil {
options.onDisconnectFunc = func(s string) {}
options.onDisconnectFunc = func(string, int) {}
}
}

Expand All @@ -180,12 +180,12 @@ func WithOnConnect(f func(heartbeatKind string)) ControllerOption {
}
}

// WithOnDisconnect sets a function to be called every time an existing
// instance disconnects from the inventory control stream. The value
// provided to the callback is the keep alive type of the disconnected
// resource. The callback should return quickly so as not to prevent
// processing of heartbeats.
func WithOnDisconnect(f func(heartbeatKind string)) ControllerOption {
// WithOnDisconnect sets a function to be called every time an existing instance
// disconnects from the inventory control stream. The values provided to the
// callback are the keep alive type of the disconnected resource, as well as a
// count of how many resources disconnected at once. The callback should return
// quickly so as not to prevent processing of heartbeats.
func WithOnDisconnect(f func(heartbeatKind string, amount int)) ControllerOption {
return func(opts *controllerOptions) {
opts.onDisconnectFunc = f
}
Expand Down Expand Up @@ -226,7 +226,7 @@ type Controller struct {
usageReporter usagereporter.UsageReporter
testEvents chan testEvent
onConnectFunc func(string)
onDisconnectFunc func(string)
onDisconnectFunc func(string, int)
closeContext context.Context
cancel context.CancelFunc
}
Expand Down Expand Up @@ -351,9 +351,10 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) {
defer func() {
if handle.goodbye.GetDeleteResources() {
log.WithFields(log.Fields{
"apps": len(handle.appServers),
"dbs": len(handle.databaseServers),
"kube": len(handle.kubernetesServers),
"apps": len(handle.appServers),
"dbs": len(handle.databaseServers),
"kube": len(handle.kubernetesServers),
"server_id": handle.Hello().ServerID,
}).Debug("Cleaning up resources in response to instance termination")
for _, app := range handle.appServers {
if err := c.auth.DeleteApplicationServer(c.closeContext, apidefaults.Namespace, app.resource.GetHostID(), app.resource.GetName()); err != nil && !trace.IsNotFound(err) {
Expand Down Expand Up @@ -383,19 +384,19 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) {
handle.ticker.Stop()

if handle.sshServer != nil {
c.onDisconnectFunc(constants.KeepAliveNode)
c.onDisconnectFunc(constants.KeepAliveNode, 1)
}

for range handle.appServers {
c.onDisconnectFunc(constants.KeepAliveApp)
if len(handle.appServers) > 0 {
c.onDisconnectFunc(constants.KeepAliveApp, len(handle.appServers))
}

for range handle.databaseServers {
c.onDisconnectFunc(constants.KeepAliveDatabase)
if len(handle.databaseServers) > 0 {
c.onDisconnectFunc(constants.KeepAliveDatabase, len(handle.databaseServers))
}

for range handle.kubernetesServers {
c.onDisconnectFunc(constants.KeepAliveKube)
if len(handle.kubernetesServers) > 0 {
c.onDisconnectFunc(constants.KeepAliveKube, len(handle.kubernetesServers))
}

clear(handle.appServers)
Expand Down Expand Up @@ -845,6 +846,7 @@ func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time) e

if shouldRemove {
c.testEvent(appKeepAliveDel)
c.onDisconnectFunc(constants.KeepAliveApp, 1)
delete(handle.appServers, name)
}
} else {
Expand Down Expand Up @@ -887,6 +889,7 @@ func (c *Controller) keepAliveDatabaseServer(handle *upstreamHandle, now time.Ti

if shouldRemove {
c.testEvent(dbKeepAliveDel)
c.onDisconnectFunc(constants.KeepAliveDatabase, 1)
delete(handle.databaseServers, name)
}
} else {
Expand Down Expand Up @@ -929,6 +932,7 @@ func (c *Controller) keepAliveKubernetesServer(handle *upstreamHandle, now time.

if shouldRemove {
c.testEvent(kubeKeepAliveDel)
c.onDisconnectFunc(constants.KeepAliveKube, 1)
delete(handle.kubernetesServers, name)
}
} else {
Expand Down
59 changes: 58 additions & 1 deletion lib/inventory/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,14 @@ func TestSSHServerBasics(t *testing.T) {
expectAddr: wantAddr,
}

rc := &resourceCounter{}
controller := NewController(
auth,
usagereporter.DiscardUsageReporter{},
withServerKeepAlive(time.Millisecond*200),
withTestEventsChannel(events),
WithOnConnect(rc.onConnect),
WithOnDisconnect(rc.onDisconnect),
)
defer controller.Close()

Expand Down Expand Up @@ -314,6 +317,9 @@ func TestSSHServerBasics(t *testing.T) {
// here).
require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count())

// verify that metrics have been updated correctly
require.Zero(t, 0, rc.count())

// verify that the peer address of the control stream was used to override
// zero-value IPs for heartbeats.
auth.mu.Lock()
Expand All @@ -337,11 +343,14 @@ func TestAppServerBasics(t *testing.T) {

auth := &fakeAuth{}

rc := &resourceCounter{}
controller := NewController(
auth,
usagereporter.DiscardUsageReporter{},
withServerKeepAlive(time.Millisecond*200),
withTestEventsChannel(events),
WithOnConnect(rc.onConnect),
WithOnDisconnect(rc.onDisconnect),
)
defer controller.Close()

Expand Down Expand Up @@ -532,6 +541,9 @@ func TestAppServerBasics(t *testing.T) {
// always *before* closure is propagated to downstream handle, hence being safe to load
// here).
require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count())

// verify that metrics have been updated correctly
require.Zero(t, rc.count())
}

// TestDatabaseServerBasics verifies basic expected behaviors for a single control stream heartbeating
Expand All @@ -549,11 +561,14 @@ func TestDatabaseServerBasics(t *testing.T) {

auth := &fakeAuth{}

rc := &resourceCounter{}
controller := NewController(
auth,
usagereporter.DiscardUsageReporter{},
withServerKeepAlive(time.Millisecond*200),
withTestEventsChannel(events),
WithOnConnect(rc.onConnect),
WithOnDisconnect(rc.onDisconnect),
)
defer controller.Close()

Expand Down Expand Up @@ -745,6 +760,9 @@ func TestDatabaseServerBasics(t *testing.T) {
// always *before* closure is propagated to downstream handle, hence being safe to load
// here).
require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count())

// verify that metrics have been updated correctly
require.Zero(t, rc.count())
}

// TestInstanceHeartbeat verifies basic expected behaviors for instance heartbeat.
Expand Down Expand Up @@ -1154,11 +1172,14 @@ func TestKubernetesServerBasics(t *testing.T) {

auth := &fakeAuth{}

rc := &resourceCounter{}
controller := NewController(
auth,
usagereporter.DiscardUsageReporter{},
withServerKeepAlive(time.Millisecond*200),
withTestEventsChannel(events),
WithOnConnect(rc.onConnect),
WithOnDisconnect(rc.onDisconnect),
)
defer controller.Close()

Expand Down Expand Up @@ -1354,10 +1375,12 @@ func TestKubernetesServerBasics(t *testing.T) {
// always *before* closure is propagated to downstream handle, hence being safe to load
// here).
require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count())

// verify that metrics have been updated correctly
require.Zero(t, rc.count())
}

func TestGetSender(t *testing.T) {

controller := NewController(
&fakeAuth{},
usagereporter.DiscardUsageReporter{},
Expand Down Expand Up @@ -1468,3 +1491,37 @@ func awaitEvents(t *testing.T, ch <-chan testEvent, opts ...eventOption) {
}
}
}

type resourceCounter struct {
mu sync.Mutex
c map[string]int
}

func (r *resourceCounter) onConnect(typ string) {
r.mu.Lock()
defer r.mu.Unlock()
if r.c == nil {
r.c = make(map[string]int)
}
r.c[typ]++
}

func (r *resourceCounter) onDisconnect(typ string, amount int) {
r.mu.Lock()
defer r.mu.Unlock()
if r.c == nil {
r.c = make(map[string]int)
}
r.c[typ] -= amount
}

func (r *resourceCounter) count() int {
r.mu.Lock()
defer r.mu.Unlock()

var count int
for _, v := range r.c {
count += v
}
return count
}

0 comments on commit 91bb17a

Please sign in to comment.