From aa26424885e7c34ce1392a5e7db27b94a916cc6c Mon Sep 17 00:00:00 2001
From: Edoardo Spadolini
Date: Fri, 24 May 2024 19:27:56 +0200
Subject: [PATCH] Add an optional limit for concurrent `CreateAuditStream` operations (#41957)

* Add unstable envvar to limit in-flight CreateAuditStream RPCs

* Test the concurrency limit on CreateAuditStream

* Log on parsing errors

* Warn on limit 0, for good measure
---
 lib/auth/grpcserver.go      | 65 +++++++++++++++++++++++++++++++++++++
 lib/auth/grpcserver_test.go | 47 +++++++++++++++++++++++++++
 2 files changed, 112 insertions(+)

diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go
index 51fdbe5a659da..927c658f18820 100644
--- a/lib/auth/grpcserver.go
+++ b/lib/auth/grpcserver.go
@@ -25,6 +25,8 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"os"
+	"strconv"
 	"time"
 
 	"github.com/coreos/go-semver/semver"
@@ -137,6 +139,24 @@ var (
 		},
 		[]string{teleport.TagType},
 	)
+
+	createAuditStreamAcceptedTotalMetric = prometheus.NewCounter(prometheus.CounterOpts{
+		Namespace: teleport.MetricNamespace,
+		Name:      "unstable_createauditstream_accepted_total",
+		Help:      "CreateAuditStream RPCs accepted by the concurrency limiter",
+	})
+
+	createAuditStreamRejectedTotalMetric = prometheus.NewCounter(prometheus.CounterOpts{
+		Namespace: teleport.MetricNamespace,
+		Name:      "unstable_createauditstream_rejected_total",
+		Help:      "CreateAuditStream RPCs rejected by the concurrency limiter",
+	})
+
+	createAuditStreamLimitMetric = prometheus.NewGauge(prometheus.GaugeOpts{
+		Namespace: teleport.MetricNamespace,
+		Name:      "unstable_createauditstream_limit",
+		Help:      "Configured limit of in-flight CreateAuditStream RPCs",
+	})
 )
 
 // GRPCServer is gRPC Auth Server API
@@ -164,6 +184,11 @@ type GRPCServer struct {
 	// TraceServiceServer exposes the exporter server so that the auth server may
 	// collect and forward spans
 	collectortracepb.TraceServiceServer
+
+	// createAuditStreamSemaphore, if not nil, is used to limit the number of
+	// in-flight CreateAuditStream RPCs, by sending a value in at the beginning
+	// of the RPC and pulling one out before returning.
+	createAuditStreamSemaphore chan struct{}
 }
 
 // Export forwards OTLP traces to the upstream collector configured in the tracing service. This allows for
@@ -266,6 +291,21 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
 		return trace.Wrap(err)
 	}
 
+	if sem := g.createAuditStreamSemaphore; sem != nil {
+		select {
+		case sem <- struct{}{}:
+			createAuditStreamAcceptedTotalMetric.Inc()
+			defer func() { <-sem }()
+		default:
+			createAuditStreamRejectedTotalMetric.Inc()
+			// [trace.ConnectionProblemError] is rendered with a gRPC
+			// "unavailable" error code, which is the correct error if the
+			// client can just back off and retry with no further changes to the
+			// request
+			return trace.ConnectionProblem(nil, "too many concurrent CreateAuditStream operations, try again later")
+		}
+	}
+
 	var eventStream apievents.Stream
 	var sessionID session.ID
 	g.Debugf("CreateAuditStream connection from %v.", auth.User.GetName())
@@ -5147,6 +5187,31 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
 		presenceService: presenceService,
 	}
 
+	// The limit on in-flight CreateAuditStream RPCs is opt-in: the limiter and
+	// its metrics are only registered if the envvar is set, and a value that
+	// is negative or fails to parse leaves the limit unenforced.
+	if en := os.Getenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT"); en != "" {
+		inflightLimit, err := strconv.ParseInt(en, 10, 64)
+		if err != nil {
+			log.WithError(err).Error("Failed to parse the TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT envvar, limit will not be enforced.")
+			inflightLimit = -1
+		}
+		if inflightLimit == 0 {
+			log.Warn("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT is set to 0, no CreateAuditStream RPCs will be allowed.")
+		}
+		if err := metrics.RegisterPrometheusCollectors(
+			createAuditStreamAcceptedTotalMetric,
+			createAuditStreamRejectedTotalMetric,
+			createAuditStreamLimitMetric,
+		); err != nil {
+			return nil, trace.Wrap(err)
+		}
+		createAuditStreamLimitMetric.Set(float64(inflightLimit))
+		if inflightLimit >= 0 {
+			authServer.createAuditStreamSemaphore = make(chan struct{}, inflightLimit)
+		}
+	}
+
 	authpb.RegisterAuthServiceServer(server, authServer)
 	collectortracepb.RegisterTraceServiceServer(server, authServer)
 	auditlogpb.RegisterAuditLogServiceServer(server, authServer)
diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go
index 88aa325e6152e..58a10455c1c56 100644
--- a/lib/auth/grpcserver_test.go
+++ b/lib/auth/grpcserver_test.go
@@ -40,6 +40,7 @@ import (
 	"github.com/jonboulle/clockwork"
 	"github.com/pquerna/otp"
 	"github.com/pquerna/otp/totp"
+	prom_client_model "github.com/prometheus/client_model/go"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	otlpcommonv1 "go.opentelemetry.io/proto/otlp/common/v1"
@@ -71,6 +72,7 @@ import (
 	dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz"
 	"github.com/gravitational/teleport/lib/modules"
 	"github.com/gravitational/teleport/lib/services"
+	"github.com/gravitational/teleport/lib/session"
 	"github.com/gravitational/teleport/lib/tlsca"
 )
 
@@ -4469,3 +4471,48 @@ func TestGetAccessGraphConfig(t *testing.T) {
 		})
 	}
 }
+
+func TestCreateAuditStreamLimit(t *testing.T) {
+	const N = 5
+	t.Setenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT", fmt.Sprintf("%d", N))
+
+	ctx, cancel := context.WithCancel(context.Background())
+	// cleanups run after the deferred calls of the test function, in
+	// last-in-first-out order: registering the cancellation here lets the
+	// stream cleanups below close their streams before ctx is canceled
+	t.Cleanup(cancel)
+
+	server := newTestTLSServer(t)
+	clt, err := server.NewClient(TestServerID(types.RoleNode, uuid.NewString()))
+	require.NoError(t, err)
+
+	// HACK(espadolini): we're piggybacking on the prometheus counter which
+	// can't change while this test is running (we set an envvar, so we can't be
+	// running in parallel with other tests) but it's still pretty awful, and
+	// it'd be much better to actually check that the streams were accepted by
+	// the server; unfortunately, the CreateAuditStream stream doesn't actually
+	// send anything back unless there's a real upload going on, and the test
+	// server uses a discard emitter which never ends up sending anything
+	getAcceptedTotal := func() int {
+		var m prom_client_model.Metric
+		require.NoError(t, createAuditStreamAcceptedTotalMetric.Write(&m))
+		return int(m.Counter.GetValue())
+	}
+	currentAcceptedTotal := getAcceptedTotal()
+
+	for i := 0; i < N; i++ {
+		stream, err := clt.CreateAuditStream(ctx, session.NewID())
+		require.NoError(t, err)
+		t.Cleanup(func() { stream.Close(ctx) })
+	}
+
+	require.EventuallyWithT(t, func(t *assert.CollectT) {
+		assert.EqualValues(t, currentAcceptedTotal+N, getAcceptedTotal())
+	}, time.Second, 100*time.Millisecond)
+
+	ac := proto.NewAuthServiceClient(clt.APIClient.GetConnection())
+	stream, err := ac.CreateAuditStream(ctx)
+	require.NoError(t, err)
+	_, err = stream.Recv()
+	require.ErrorAs(t, err, new(*trace.ConnectionProblemError))
+}