Skip to content

Commit

Permalink
Add an optional limit for concurrent CreateAuditStream operations (#…
Browse files Browse the repository at this point in the history
…41957)

* Add unstable envvar to limit in-flight CreateAuditStream RPCs

* Test the concurrency limit on CreateAuditStream

* Log on parsing errors

* Warn on limit 0, for good measure
  • Loading branch information
espadolini authored May 24, 2024
1 parent 2f15817 commit aa26424
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
60 changes: 60 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"fmt"
"io"
"net"
"os"
"strconv"
"time"

"github.com/coreos/go-semver/semver"
Expand Down Expand Up @@ -137,6 +139,24 @@ var (
},
[]string{teleport.TagType},
)

createAuditStreamAcceptedTotalMetric = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Name: "unstable_createauditstream_accepted_total",
Help: "CreateAuditStream RPCs accepted by the concurrency limiter",
})

createAuditStreamRejectedTotalMetric = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Name: "unstable_createauditstream_rejected_total",
Help: "CreateAuditStream RPCs rejected by the concurrency limiter",
})

createAuditStreamLimitMetric = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: teleport.MetricNamespace,
Name: "unstable_createauditstream_limit",
Help: "Configured limit of in-flight CreateAuditStream RPCs",
})
)

// GRPCServer is gRPC Auth Server API
Expand Down Expand Up @@ -164,6 +184,11 @@ type GRPCServer struct {
// TraceServiceServer exposes the exporter server so that the auth server may
// collect and forward spans
collectortracepb.TraceServiceServer

// createAuditStreamSemaphore, if not nil, is used to limit the number of
// in-flight CreateAuditStream RPCs, by sending a value in at the beginning
// of the RPC and pulling one out before returning.
createAuditStreamSemaphore chan struct{}
}

// Export forwards OTLP traces to the upstream collector configured in the tracing service. This allows for
Expand Down Expand Up @@ -266,6 +291,21 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
return trace.Wrap(err)
}

if sem := g.createAuditStreamSemaphore; sem != nil {
select {
case sem <- struct{}{}:
createAuditStreamAcceptedTotalMetric.Inc()
defer func() { <-sem }()
default:
createAuditStreamRejectedTotalMetric.Inc()
// [trace.ConnectionProblemError] is rendered with a gRPC
// "unavailable" error code, which is the correct error if the
// client can just back off and retry with no further changes to the
// request
return trace.ConnectionProblem(nil, "too many concurrent CreateAuditStream operations, try again later")
}
}

var eventStream apievents.Stream
var sessionID session.ID
g.Debugf("CreateAuditStream connection from %v.", auth.User.GetName())
Expand Down Expand Up @@ -5147,6 +5187,26 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
presenceService: presenceService,
}

if en := os.Getenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT"); en != "" {
inflightLimit, err := strconv.ParseInt(en, 10, 64)
if err != nil {
log.Error("Failed to parse the TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT envvar, limit will not be enforced.")
inflightLimit = -1
}
if inflightLimit == 0 {
log.Warn("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT is set to 0, no CreateAuditStream RPCs will be allowed.")
}
metrics.RegisterPrometheusCollectors(
createAuditStreamAcceptedTotalMetric,
createAuditStreamRejectedTotalMetric,
createAuditStreamLimitMetric,
)
createAuditStreamLimitMetric.Set(float64(inflightLimit))
if inflightLimit >= 0 {
authServer.createAuditStreamSemaphore = make(chan struct{}, inflightLimit)
}
}

authpb.RegisterAuthServiceServer(server, authServer)
collectortracepb.RegisterTraceServiceServer(server, authServer)
auditlogpb.RegisterAuditLogServiceServer(server, authServer)
Expand Down
44 changes: 44 additions & 0 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
prom_client_model "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
otlpcommonv1 "go.opentelemetry.io/proto/otlp/common/v1"
Expand Down Expand Up @@ -71,6 +72,7 @@ import (
dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/tlsca"
)

Expand Down Expand Up @@ -4469,3 +4471,45 @@ func TestGetAccessGraphConfig(t *testing.T) {
})
}
}

// TestCreateAuditStreamLimit checks the optional concurrency limit on
// CreateAuditStream: with the limit set to N, the first N streams are accepted
// and the next one fails with a connection problem error.
//
// The limit is configured through the
// TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT environment variable, so
// this test must not run in parallel with other tests.
func TestCreateAuditStreamLimit(t *testing.T) {
	const N = 5
	t.Setenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT", fmt.Sprintf("%d", N))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server := newTestTLSServer(t)
	clt, err := server.NewClient(TestServerID(types.RoleNode, uuid.NewString()))
	require.NoError(t, err)

	// HACK(espadolini): we're piggybacking on the prometheus counter which
	// can't change while this test is running (we set an envvar, so we can't be
	// running in parallel with other tests) but it's still pretty awful, and
	// it'd be much better to actually check that the streams were accepted by
	// the server; unfortunately, the CreateAuditStream stream doesn't actually
	// send anything back unless there's a real upload going on, and the test
	// server uses a discard emitter which never ends up sending anything
	//
	// readAcceptedTotal returns the error instead of asserting on it, because
	// it is also called from the EventuallyWithT condition, which testify runs
	// on its own goroutine; calling FailNow on the outer *testing.T from there
	// is not allowed by the testing package.
	readAcceptedTotal := func() (int, error) {
		var m prom_client_model.Metric
		if err := createAuditStreamAcceptedTotalMetric.Write(&m); err != nil {
			return 0, err
		}
		return int(m.Counter.GetValue()), nil
	}

	// The counter is process-wide, so compare against the value observed
	// before any stream is opened rather than against zero.
	baselineAcceptedTotal, err := readAcceptedTotal()
	require.NoError(t, err)

	// Fill every slot of the limiter; the streams are kept open until cleanup.
	for i := 0; i < N; i++ {
		stream, err := clt.CreateAuditStream(ctx, session.NewID())
		require.NoError(t, err)
		t.Cleanup(func() { stream.Close(ctx) })
	}

	// The accepted counter is incremented by the server handler, so poll until
	// all N streams have been accounted for.
	require.EventuallyWithT(t, func(t *assert.CollectT) {
		acceptedTotal, err := readAcceptedTotal()
		if !assert.NoError(t, err) {
			return
		}
		assert.EqualValues(t, baselineAcceptedTotal+N, acceptedTotal)
	}, time.Second, 100*time.Millisecond)

	// With all N slots taken, one more RPC must be rejected. Opening the
	// stream succeeds; the rejection is only observed on the first Recv.
	ac := proto.NewAuthServiceClient(clt.APIClient.GetConnection())
	stream, err := ac.CreateAuditStream(ctx)
	require.NoError(t, err)
	_, err = stream.Recv()
	require.ErrorAs(t, err, new(*trace.ConnectionProblemError))
}

0 comments on commit aa26424

Please sign in to comment.