Skip to content

Commit

Permalink
[v14] Connector circuit breaker metrics (#40755)
Browse files Browse the repository at this point in the history
* Fix time.After in joinserver

* Log failed cluster joins

* Instrument teleport instance connector breakers

* Address comments, fix nil panic
  • Loading branch information
espadolini authored Apr 22, 2024
1 parent 0d5dbfa commit 4f2fdfd
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 42 deletions.
45 changes: 33 additions & 12 deletions api/breaker/breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,18 @@ type Config struct {
// StateStandby to StateTripped. This is required to be supplied, failure to do so will result in an error
// creating the CircuitBreaker.
Trip TripFn
// OnTripped will be called when the CircuitBreaker enters the StateTripped state
// OnTripped will be called when the CircuitBreaker enters the StateTripped
// state; this callback is called while holding a lock, so it should return
// quickly.
OnTripped func()
// OnStandby will be called when the CircuitBreaker returns to the StateStandby state
// OnStandBy will be called when the CircuitBreaker returns to the
// StateStandby state; this callback is called while holding a lock, so it
// should return quickly.
OnStandBy func()
// OnExecute will be called once for each execution, and given the result
// and the current state of the breaker; this callback is called while
// holding a lock, so it should return quickly.
OnExecute func(success bool, state State)
// IsSuccessful is used by the CircuitBreaker to determine if the executed function was successful or not
IsSuccessful func(v interface{}, err error) bool
// Logger is the logger
Expand All @@ -139,6 +147,12 @@ type Config struct {
TrippedErrorMessage string
}

// Clone produces an independent Config value from the receiver. A plain
// struct copy is all that is needed for the current set of fields.
func (c *Config) Clone() Config {
	cloned := *c
	return cloned
}

// TripFn determines if the CircuitBreaker should be tripped based
// on the state of the provided Metrics. A return value of true will
// cause the CircuitBreaker to transition into the StateTripped state
Expand Down Expand Up @@ -256,6 +270,10 @@ func (c *Config) CheckAndSetDefaults() error {
c.OnStandBy = func() {}
}

if c.OnExecute == nil {
c.OnExecute = func(bool, State) {}
}

if c.IsSuccessful == nil {
c.IsSuccessful = NonNilErrorIsSuccess
}
Expand Down Expand Up @@ -332,8 +350,9 @@ func (c *CircuitBreaker) beforeExecution() (uint64, error) {

generation, state := c.currentState(now)

switch {
case state == StateTripped:
if state == StateTripped {
c.cfg.OnExecute(false, StateTripped)

if c.cfg.TrippedErrorMessage != "" {
return generation, trace.ConnectionProblem(nil, c.cfg.TrippedErrorMessage)
}
Expand All @@ -359,21 +378,21 @@ func (c *CircuitBreaker) afterExecution(prior uint64, v interface{}, err error)
}

if c.cfg.IsSuccessful(v, err) {
c.cfg.Logger.Debugf("successful execution, %s", c.metrics.String())
c.success(state, now)
c.successLocked(state, now)
} else {
c.cfg.Logger.Debugf("failed execution, %s", c.metrics.String())
c.failure(state, now)
c.failureLocked(state, now)
}
}

// success tallies a successful execution and migrates to StateStandby
// successLocked tallies a successful execution and migrates to StateStandby
// if in another state and the criteria have been met to transition
func (c *CircuitBreaker) success(state State, t time.Time) {
func (c *CircuitBreaker) successLocked(state State, t time.Time) {
switch state {
case StateStandby:
c.cfg.OnExecute(true, StateStandby)
c.metrics.success()
case StateRecovering:
c.cfg.OnExecute(true, StateRecovering)
c.metrics.success()
if c.metrics.ConsecutiveSuccesses >= c.cfg.RecoveryLimit {
c.setState(StateStandby, t)
Expand All @@ -382,17 +401,19 @@ func (c *CircuitBreaker) success(state State, t time.Time) {
}
}

// failure tallies a failed execution and migrate to StateTripped
// failureLocked tallies a failed execution and migrates to StateTripped
// if in another state and the criteria have been met to transition
func (c *CircuitBreaker) failure(state State, t time.Time) {
func (c *CircuitBreaker) failureLocked(state State, t time.Time) {
c.metrics.failure()

switch state {
case StateRecovering:
c.cfg.OnExecute(false, StateRecovering)
if c.cfg.Recover(c.metrics) {
c.setState(StateTripped, t)
}
case StateStandby:
c.cfg.OnExecute(false, StateStandby)
if c.cfg.Trip(c.metrics) {
c.setState(StateTripped, t)
go c.cfg.OnTripped()
Expand Down
4 changes: 2 additions & 2 deletions api/breaker/breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func TestCircuitBreaker_success(t *testing.T) {
cb.state = tt.initialState

generation, state := cb.currentState(clock.Now())
cb.success(tt.successState, clock.Now())
cb.successLocked(tt.successState, clock.Now())
require.Equal(t, tt.expectedState, cb.state)
if tt.expectedState != state {
require.NotEqual(t, generation, cb.generation)
Expand Down Expand Up @@ -341,7 +341,7 @@ func TestCircuitBreaker_failure(t *testing.T) {
cb.state = tt.initialState

generation, state := cb.currentState(clock.Now())
cb.failure(tt.failureState, clock.Now())
cb.failureLocked(tt.failureState, clock.Now())
require.Equal(t, tt.expectedState, cb.state)
if tt.expectedState != state {
require.NotEqual(t, generation, cb.generation)
Expand Down
24 changes: 21 additions & 3 deletions lib/auth/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/peer"

"github.com/gravitational/teleport/api/client/proto"
Expand Down Expand Up @@ -122,14 +123,31 @@ func setRemoteAddrFromContext(ctx context.Context, req *types.RegisterUsingToken
//
// If the token includes a specific join method, the rules for that join method
// will be checked.
func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
log.Infof("Node %q [%v] is trying to join with role: %v.", req.NodeName, req.HostID, req.Role)
func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (_ *proto.Certs, err error) {
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}

method := a.tokenJoinMethod(ctx, req.Token)
defer func() {
if err == nil {
return
}
level := logrus.WarnLevel
if trace.IsAccessDenied(err) {
level = logrus.DebugLevel
}
log.WithFields(logrus.Fields{
"node_name": req.NodeName,
"host_id": req.HostID,
"role": req.Role,
"method": method,
logrus.ErrorKey: err,
}).Log(level, "Agent has failed to join the cluster.")
}()

var joinAttributeSrc joinAttributeSourcer
switch method := a.tokenJoinMethod(ctx, req.Token); method {
switch method {
case types.JoinMethodEC2:
if err := a.checkEC2JoinRequest(ctx, req); err != nil {
return nil, trace.Wrap(err)
Expand Down
22 changes: 21 additions & 1 deletion lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
Expand Down Expand Up @@ -337,7 +338,7 @@ func withFips(fips bool) iamRegisterOption {
// The caller must provide a ChallengeResponseFunc which returns a
// *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request
// including the challenge as a signed header.
func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc, opts ...iamRegisterOption) (*proto.Certs, error) {
func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc, opts ...iamRegisterOption) (_ *proto.Certs, err error) {
cfg := defaultIAMRegisterConfig(a.fips)
for _, opt := range opts {
opt(cfg)
Expand All @@ -357,11 +358,30 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c
return nil, trace.Wrap(err)
}

var method types.JoinMethod = "unknown"
defer func() {
if err == nil {
return
}
level := logrus.WarnLevel
if trace.IsAccessDenied(err) {
level = logrus.DebugLevel
}
log.WithFields(logrus.Fields{
"node_name": req.RegisterUsingTokenRequest.NodeName,
"host_id": req.RegisterUsingTokenRequest.HostID,
"role": req.RegisterUsingTokenRequest.Role,
"method": method,
logrus.ErrorKey: err,
}).Log(level, "Agent has failed to join the cluster.")
}()

// perform common token checks
provisionToken, err := a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest)
if err != nil {
return nil, trace.Wrap(err)
}
method = provisionToken.GetJoinMethod()

// check that the GetCallerIdentity request is valid and matches the token
if err := a.checkIAMRequest(ctx, challenge, req, cfg); err != nil {
Expand Down
32 changes: 14 additions & 18 deletions lib/joinserver/joinserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,37 +72,35 @@ func NewJoinServiceGRPCServer(joinServiceClient joinServiceClient) *JoinServiceG
// sts:GetCallerIdentity request with the challenge string. Finally, the signed
// cluster certs are sent on the server stream.
func (s *JoinServiceGRPCServer) RegisterUsingIAMMethod(srv proto.JoinService_RegisterUsingIAMMethodServer) error {
ctx := srv.Context()

// Enforce a timeout on the entire RPC so that misbehaving clients cannot
// hold connections open indefinitely.
timeout := s.clock.After(iamJoinRequestTimeout)
timeout := s.clock.NewTimer(iamJoinRequestTimeout)
defer timeout.Stop()

// The only way to cancel a blocked Send or Recv on the server side without
// adding an interceptor to the entire gRPC service is to return from the
// handler https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474
errCh := make(chan error, 1)
go func() {
errCh <- s.registerUsingIAMMethod(ctx, srv)
errCh <- s.registerUsingIAMMethod(srv)
}()
select {
case err := <-errCh:
// Completed before the deadline, return the error (may be nil).
return trace.Wrap(err)
case <-timeout:
case <-timeout.Chan():
nodeAddr := ""
if peerInfo, ok := peer.FromContext(ctx); ok {
if peerInfo, ok := peer.FromContext(srv.Context()); ok {
nodeAddr = peerInfo.Addr.String()
}
logrus.Warnf("IAM join attempt timed out, node at (%s) is misbehaving or did not close the connection after encountering an error.", nodeAddr)
// Returning here should cancel any blocked Send or Recv operations.
return trace.LimitExceeded("RegisterUsingIAMMethod timed out after %s, terminating the stream on the server", iamJoinRequestTimeout)
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}

func (s *JoinServiceGRPCServer) registerUsingIAMMethod(ctx context.Context, srv proto.JoinService_RegisterUsingIAMMethodServer) error {
func (s *JoinServiceGRPCServer) registerUsingIAMMethod(srv proto.JoinService_RegisterUsingIAMMethodServer) error {
ctx := srv.Context()
// Call RegisterUsingIAMMethod with a callback to get the challenge response
// from the gRPC client.
certs, err := s.joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
Expand Down Expand Up @@ -144,33 +142,30 @@ func (s *JoinServiceGRPCServer) registerUsingIAMMethod(ctx context.Context, srv
// attested data document with the challenge string. Finally, the signed
// cluster certs are sent on the server stream.
func (s *JoinServiceGRPCServer) RegisterUsingAzureMethod(srv proto.JoinService_RegisterUsingAzureMethodServer) error {
ctx := srv.Context()

// Enforce a timeout on the entire RPC so that misbehaving clients cannot
// hold connections open indefinitely.
timeout := s.clock.After(azureJoinRequestTimeout)
timeout := s.clock.NewTimer(azureJoinRequestTimeout)
defer timeout.Stop()

// The only way to cancel a blocked Send or Recv on the server side without
// adding an interceptor to the entire gRPC service is to return from the
// handler https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474
errCh := make(chan error, 1)
go func() {
errCh <- s.registerUsingAzureMethod(ctx, srv)
errCh <- s.registerUsingAzureMethod(srv)
}()
select {
case err := <-errCh:
// Completed before the deadline, return the error (may be nil).
return trace.Wrap(err)
case <-timeout:
case <-timeout.Chan():
nodeAddr := ""
if peerInfo, ok := peer.FromContext(ctx); ok {
if peerInfo, ok := peer.FromContext(srv.Context()); ok {
nodeAddr = peerInfo.Addr.String()
}
logrus.Warnf("Azure join attempt timed out, node at (%s) is misbehaving or did not close the connection after encountering an error.", nodeAddr)
// Returning here should cancel any blocked Send or Recv operations.
return trace.LimitExceeded("RegisterUsingAzureMethod timed out after %s, terminating the stream on the server", azureJoinRequestTimeout)
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}

Expand All @@ -196,7 +191,8 @@ func setClientRemoteAddr(ctx context.Context, req *types.RegisterUsingTokenReque
return nil
}

func (s *JoinServiceGRPCServer) registerUsingAzureMethod(ctx context.Context, srv proto.JoinService_RegisterUsingAzureMethodServer) error {
func (s *JoinServiceGRPCServer) registerUsingAzureMethod(srv proto.JoinService_RegisterUsingAzureMethodServer) error {
ctx := srv.Context()
certs, err := s.joinServiceClient.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
err := srv.Send(&proto.RegisterUsingAzureMethodResponse{
Challenge: challenge,
Expand Down
56 changes: 56 additions & 0 deletions lib/service/breaker/breaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package breaker

import (
"strconv"
"sync"

"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/types"
)

// connectorExecutions counts executions that went through an instrumented
// connector breaker, partitioned by the "role", "state" and "success" labels.
var connectorExecutions = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Subsystem: "breaker",
Name: "connector_executions_total",
Help: "Client requests per system role, state of the breaker and success as interpreted by the breaker.",
}, []string{"role", "state", "success"})

// registerOnce guards the one-time registration of connectorExecutions
// performed by ensureRegistered.
var registerOnce sync.Once

func ensureRegistered() {
registerOnce.Do(func() {
prometheus.MustRegister(connectorExecutions)
})
}

// InstrumentBreakerForConnector takes a [breaker.Config] and hands back a
// clone whose OnExecute callback increments the connector executions counter
// for every client "execution" (i.e. request or stream) that goes through the
// breaker. Each increment is labeled with the given system role, the breaker
// state and whether the execution was considered successful. Any OnExecute
// callback already set on cfg is replaced in the returned copy.
func InstrumentBreakerForConnector(role types.SystemRole, cfg breaker.Config) breaker.Config {
	// Make sure the counter is registered before any callback can fire.
	ensureRegistered()

	instrumented := cfg.Clone()
	instrumented.OnExecute = func(ok bool, st breaker.State) {
		counter := connectorExecutions.WithLabelValues(
			role.String(),
			st.String(),
			strconv.FormatBool(ok),
		)
		counter.Inc()
	}
	return instrumented
}
Loading

0 comments on commit 4f2fdfd

Please sign in to comment.