Skip to content

Commit

Permalink
Update HeartbeatV2Config.GetResource to be fallible (#46550) (#46573)
Browse files Browse the repository at this point in the history
Up until now, the resources converted to srv.HeartbbeatV2 have
not been fallible. However, in order to support KubernetesServer
resources, which rely on external factors to populate the resource,
the GetResource function needs to be updated to return an error.
  • Loading branch information
rosstimothy authored Sep 13, 2024
1 parent 4919fe7 commit 429e768
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 48 deletions.
34 changes: 15 additions & 19 deletions lib/srv/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,37 +494,33 @@ func (s *Server) stopHeartbeat(name string) error {

// getServerInfoFunc returns function that the heartbeater uses to report the
// provided application to the auth server.
func (s *Server) getServerInfoFunc(app types.Application) func() *types.AppServerV3 {
return func() *types.AppServerV3 {
func (s *Server) getServerInfoFunc(app types.Application) func(context.Context) (*types.AppServerV3, error) {
return func(context.Context) (*types.AppServerV3, error) {
return s.getServerInfo(app)
}
}

// getServerInfo returns up-to-date app resource.
func (s *Server) getServerInfo(app types.Application) *types.AppServerV3 {
func (s *Server) getServerInfo(app types.Application) (*types.AppServerV3, error) {
// Make sure to return a new object, because it gets cached by
// heartbeat and will always compare as equal otherwise.
s.mu.RLock()
copy := s.appWithUpdatedLabelsLocked(app)
s.mu.RUnlock()
expires := s.c.Clock.Now().UTC().Add(apidefaults.ServerAnnounceTTL)
server, err := types.NewAppServerV3(types.Metadata{
Name: copy.GetName(),
Expires: &expires,
}, types.AppServerSpecV3{
Version: teleport.Version,
Hostname: s.c.Hostname,
HostID: s.c.HostID,
Rotation: s.getRotationState(),
App: copy,
ProxyIDs: s.c.ConnectedProxyGetter.GetProxyIDs(),
})

return &types.AppServerV3{
Kind: types.KindAppServer,
Version: types.V3,
Metadata: types.Metadata{
Name: copy.GetName(),
Expires: &expires,
},
Spec: types.AppServerSpecV3{
Version: teleport.Version,
Hostname: s.c.Hostname,
HostID: s.c.HostID,
Rotation: s.getRotationState(),
App: copy,
ProxyIDs: s.c.ConnectedProxyGetter.GetProxyIDs(),
},
}
return server, trace.Wrap(err)
}

// getRotationState is a helper to return this server's CA rotation state.
Expand Down
4 changes: 3 additions & 1 deletion lib/srv/app/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,10 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite {
for _, app := range apps {
select {
case sender := <-inventoryHandle.Sender():
appServer, err := s.appServer.getServerInfo(app)
require.NoError(t, err)
require.NoError(t, sender.Send(s.closeContext, proto.InventoryHeartbeat{
AppServer: s.appServer.getServerInfo(app),
AppServer: appServer,
}))
case <-time.After(20 * time.Second):
t.Fatal("timed out waiting for inventory handle sender")
Expand Down
70 changes: 51 additions & 19 deletions lib/srv/heartbeatv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/gravitational/teleport/api/client/proto"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/defaults"
Expand All @@ -45,7 +46,7 @@ type HeartbeatV2Config[T any] struct {
// InventoryHandle is used to send heartbeats.
InventoryHandle inventory.DownstreamHandle
// GetResource gets the latest item to heartbeat.
GetResource func() T
GetResource func(context.Context) (T, error)

// -- below values are all optional

Expand Down Expand Up @@ -86,8 +87,11 @@ func NewSSHServerHeartbeat(cfg HeartbeatV2Config[*types.ServerV2]) (*HeartbeatV2
getMetadata: metadata.Get,
announcer: cfg.Announcer,
}
inner.getServer = func(ctx context.Context) *types.ServerV2 {
server := cfg.GetResource()
inner.getServer = func(ctx context.Context) (*types.ServerV2, error) {
server, err := cfg.GetResource(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

doneCtx, cancel := context.WithCancel(ctx)
cancel() // not a typo
Expand All @@ -97,7 +101,7 @@ func NewSSHServerHeartbeat(cfg HeartbeatV2Config[*types.ServerV2]) (*HeartbeatV2
server.SetCloudMetadata(meta.CloudMetadata)
}

return server
return server, nil
}

return newHeartbeatV2(cfg.InventoryHandle, inner, heartbeatV2Config{
Expand All @@ -116,7 +120,7 @@ func NewAppServerHeartbeat(cfg HeartbeatV2Config[*types.AppServerV3]) (*Heartbea
}

inner := &appServerHeartbeatV2{
getServer: func(ctx context.Context) *types.AppServerV3 { return cfg.GetResource() },
getServer: cfg.GetResource,
announcer: cfg.Announcer,
}

Expand Down Expand Up @@ -505,7 +509,7 @@ type metadataGetter func(ctx context.Context) (*metadata.Metadata, error)

// sshServerHeartbeatV2 is the heartbeatV2 implementation for ssh servers.
type sshServerHeartbeatV2 struct {
getServer func(ctx context.Context) *types.ServerV2
getServer func(ctx context.Context) (*types.ServerV2, error)
getMetadata metadataGetter
announcer authclient.Announcer
prev *types.ServerV2
Expand All @@ -515,7 +519,13 @@ func (h *sshServerHeartbeatV2) Poll(ctx context.Context) (changed bool) {
if h.prev == nil {
return true
}
return services.CompareServers(h.getServer(ctx), h.prev) == services.Different

server, err := h.getServer(ctx)
if err != nil {
return false
}

return services.CompareServers(server, h.prev) == services.Different
}

func (h *sshServerHeartbeatV2) SupportsFallback() bool {
Expand All @@ -526,32 +536,39 @@ func (h *sshServerHeartbeatV2) FallbackAnnounce(ctx context.Context) (ok bool) {
if h.announcer == nil {
return false
}
server := h.getServer(ctx)
_, err := h.announcer.UpsertNode(ctx, server)
server, err := h.getServer(ctx)
if err != nil {
log.Warnf("Failed to perform fallback heartbeat for ssh server: %v", err)
return false
}

if _, err := h.announcer.UpsertNode(ctx, server); err != nil {
log.Warnf("Failed to perform fallback heartbeat for ssh server: %v", err)
return false
}

h.prev = server
return true
}

func (h *sshServerHeartbeatV2) Announce(ctx context.Context, sender inventory.DownstreamSender) (ok bool) {
server := h.getServer(ctx)
err := sender.Send(ctx, proto.InventoryHeartbeat{
SSHServer: h.getServer(ctx),
})
server, err := h.getServer(ctx)
if err != nil {
log.Warnf("Failed to perform inventory heartbeat for ssh server: %v", err)
return false
}

if err := sender.Send(ctx, proto.InventoryHeartbeat{SSHServer: apiutils.CloneProtoMsg(server)}); err != nil {
log.Warnf("Failed to perform inventory heartbeat for ssh server: %v", err)
return false
}
h.prev = server
return true
}

// appServerHeartbeatV2 is the heartbeatV2 implementation for app servers.
type appServerHeartbeatV2 struct {
getServer func(ctx context.Context) *types.AppServerV3
getServer func(ctx context.Context) (*types.AppServerV3, error)
announcer authclient.Announcer
prev *types.AppServerV3
}
Expand All @@ -560,7 +577,13 @@ func (h *appServerHeartbeatV2) Poll(ctx context.Context) (changed bool) {
if h.prev == nil {
return true
}
return services.CompareServers(h.getServer(ctx), h.prev) == services.Different

server, err := h.getServer(ctx)
if err != nil {
return false
}

return services.CompareServers(server, h.prev) == services.Different
}

func (h *appServerHeartbeatV2) SupportsFallback() bool {
Expand All @@ -571,9 +594,13 @@ func (h *appServerHeartbeatV2) FallbackAnnounce(ctx context.Context) (ok bool) {
if h.announcer == nil {
return false
}
server := h.getServer(ctx)
_, err := h.announcer.UpsertApplicationServer(ctx, server)
server, err := h.getServer(ctx)
if err != nil {
log.Warnf("Failed to perform fallback heartbeat for app server: %v", err)
return false
}

if _, err := h.announcer.UpsertApplicationServer(ctx, server); err != nil {
if !errors.Is(err, context.Canceled) && status.Code(err) != codes.Canceled {
log.Warnf("Failed to perform fallback heartbeat for app server: %v", err)
}
Expand Down Expand Up @@ -605,8 +632,13 @@ func (h *appServerHeartbeatV2) Announce(ctx context.Context, sender inventory.Do
return h.FallbackAnnounce(ctx)
}

server := h.getServer(ctx)
if err := sender.Send(ctx, proto.InventoryHeartbeat{AppServer: h.getServer(ctx)}); err != nil {
server, err := h.getServer(ctx)
if err != nil {
log.Warnf("Failed to perform inventory heartbeat for app server: %v", err)
return false
}

if err := sender.Send(ctx, proto.InventoryHeartbeat{AppServer: apiutils.CloneProtoMsg(server)}); err != nil {
if !errors.Is(err, context.Canceled) && status.Code(err) != codes.Canceled {
log.Warnf("Failed to perform inventory heartbeat for app server: %v", err)
}
Expand Down
15 changes: 10 additions & 5 deletions lib/srv/heartbeatv2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,10 @@ func TestNewHeartbeatFetchMetadata(t *testing.T) {

heartbeat, err := NewSSHServerHeartbeat(HeartbeatV2Config[*types.ServerV2]{
InventoryHandle: &fakeDownstreamHandle{},
GetResource: func() *types.ServerV2 {
GetResource: func(context.Context) (*types.ServerV2, error) {
return &types.ServerV2{
Spec: types.ServerSpecV2{},
}
}, nil
},
})
require.NoError(t, err)
Expand All @@ -525,18 +525,23 @@ func TestNewHeartbeatFetchMetadata(t *testing.T) {
inner.getMetadata = metadataGetter

// Metadata won't be set before metadata getter returns.
server := inner.getServer(ctx)
server, err := inner.getServer(ctx)
require.NoError(t, err)
assert.Nil(t, server.GetCloudMetadata(), "Metadata was set before background process returned")

// Metadata won't be set if the getter fails.
metaCh <- nil
time.Sleep(100 * time.Millisecond) // Wait for goroutines to complete
assert.Nil(t, inner.getServer(ctx).GetCloudMetadata(), "Metadata was set despite metadata getter failing")
server, err = inner.getServer(ctx)
require.NoError(t, err)
assert.Nil(t, server.GetCloudMetadata(), "Metadata was set despite metadata getter failing")

// getServer gets updated metadata value.
metaCh <- makeMetadata("foo")
time.Sleep(100 * time.Millisecond) // Wait for goroutines to complete
meta := inner.getServer(ctx).GetCloudMetadata()
server, err = inner.getServer(ctx)
require.NoError(t, err)
meta := server.GetCloudMetadata()
assert.NotNil(t, meta, "Heartbeat never got metadata")
assert.Equal(t, "foo", meta.AWS.InstanceID)
}
6 changes: 3 additions & 3 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ func (s *Server) getBasicInfo() *types.ServerV2 {
return srv
}

func (s *Server) getServerInfo() *types.ServerV2 {
func (s *Server) getServerInfo(context.Context) (*types.ServerV2, error) {
server := s.getBasicInfo()
if s.getRotation != nil {
rotation, err := s.getRotation(s.getRole())
Expand All @@ -1081,11 +1081,11 @@ func (s *Server) getServerInfo() *types.ServerV2 {

server.SetExpiry(s.clock.Now().UTC().Add(apidefaults.ServerAnnounceTTL))
server.SetPeerAddr(s.peerAddr)
return server
return server, nil
}

func (s *Server) getServerResource() (types.Resource, error) {
return s.getServerInfo(), nil
return s.getServerInfo(s.ctx)
}

// getDirectTCPIPForwarder sets up a connection-level subprocess that handles forwarding connections. Subsequent
Expand Down
4 changes: 3 additions & 1 deletion lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO
sshSrv.Wait()
})

_, err = testServer.Auth().UpsertNode(ctx, sshSrv.getServerInfo())
server, err := sshSrv.getServerInfo(ctx)
require.NoError(t, err)
_, err = testServer.Auth().UpsertNode(ctx, server)
require.NoError(t, err)

sshSrvAddress := sshSrv.Addr()
Expand Down

0 comments on commit 429e768

Please sign in to comment.