Skip to content

Commit

Permalink
Convert auth grpc server to use slog
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy committed Dec 21, 2024
1 parent 8890e3e commit 8acd795
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 44 deletions.
2 changes: 1 addition & 1 deletion e
Submodule e updated from 7d3799 to 1ac55f
125 changes: 82 additions & 43 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"os"
"strconv"
Expand All @@ -34,7 +35,6 @@ import (
"github.com/gravitational/trace"
"github.com/gravitational/trace/trail"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
collectortracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -123,6 +123,7 @@ import (
"github.com/gravitational/teleport/lib/srv/server/installer"
usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport"
"github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

var (
Expand Down Expand Up @@ -179,7 +180,7 @@ var (
type GRPCServer struct {
authpb.UnimplementedAuthServiceServer
auditlogpb.UnimplementedAuditLogServiceServer
*logrus.Entry
logger *slog.Logger
APIConfig
server *grpc.Server

Expand Down Expand Up @@ -256,27 +257,30 @@ func (g *GRPCServer) SendKeepAlives(stream authpb.AuthService_SendKeepAlivesServ
}
keepAlive, err := stream.Recv()
if errors.Is(err, io.EOF) {
g.Logger.Debug("Connection closed.")
g.logger.DebugContext(stream.Context(), "Connection closed")
return nil
}
if err != nil {
g.Logger.Debugf("Failed to receive heartbeat: %v", err)
g.logger.DebugContext(stream.Context(), "Failed to receive heartbeat", "error", err)
return trace.Wrap(err)
}
err = auth.KeepAliveServer(stream.Context(), *keepAlive)
if err != nil {
return trace.Wrap(err)
}
if firstIteration {
g.Logger.Debugf("Got %s heartbeat connection from %v.", keepAlive.GetType(), auth.User.GetName())
g.logger.DebugContext(stream.Context(), "Got heartbeat connection",
"heartbeat_type", keepAlive.GetType(),
"identity", auth.User.GetName(),
)
heartbeatConnectionsReceived.Inc()

metric, ok := connectedResourceGauges[keepAlive.GetType()]
if ok {
metric.Inc()
defer metric.Dec()
} else {
g.Logger.Warnf("missing connected resources gauge for keep alive %s (this is a bug)", keepAlive.GetType())
g.logger.WarnContext(stream.Context(), "missing connected resources gauge for keep alive (this is a bug)", "heartbeat_type", keepAlive.GetType())
}

firstIteration = false
Expand Down Expand Up @@ -308,7 +312,7 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre

var eventStream apievents.Stream
var sessionID session.ID
g.Debugf("CreateAuditStream connection from %v.", auth.User.GetName())
g.logger.DebugContext(stream.Context(), "CreateAuditStream connection", "identity", auth.User.GetName())
streamStart := time.Now()
processed := int64(0)
counter := 0
Expand All @@ -319,7 +323,7 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
return
case statusUpdate := <-eventStream.Status():
if err := stream.Send(&statusUpdate); err != nil {
g.WithError(err).Debugf("Failed to send status update.")
g.logger.DebugContext(stream.Context(), "Failed to send status update", "error", err)
}
}
}
Expand All @@ -328,10 +332,10 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
closeStream := func(eventStream apievents.Stream) {
if err := eventStream.Close(auth.CloseContext()); err != nil {
if auth.CloseContext().Err() == nil {
g.WithError(err).Warn("Failed to flush close the stream.")
g.logger.WarnContext(stream.Context(), "Failed to flush close the stream", "error", err)
}
} else {
g.Debugf("Flushed and closed the stream.")
g.logger.DebugContext(stream.Context(), "Flushed and closed the stream")
}
}

Expand All @@ -342,7 +346,7 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
}
if err != nil {
if stream.Context().Err() == nil {
g.WithError(err).Debug("Failed to receive stream request.")
g.logger.DebugContext(stream.Context(), "Failed to receive stream request", "error", err)
}
return trace.Wrap(err)
}
Expand All @@ -355,11 +359,11 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
// Log the reason why audit stream creation failed. This will
// surface things like AWS/GCP/MinIO credential/configuration
// errors.
g.Errorf("Failed to create audit stream: %q.", err)
g.logger.ErrorContext(stream.Context(), "Failed to create audit stream", "error", err)
return trace.Wrap(err)
}
sessionID = session.ID(create.SessionID)
g.Debugf("Created stream for session %v", sessionID)
g.logger.DebugContext(stream.Context(), "Created stream for session", "session_id", sessionID)
go forwardEvents(eventStream)
defer closeStream(eventStream)
} else if resume := request.GetResumeStream(); resume != nil {
Expand All @@ -370,7 +374,7 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
if err != nil {
return trace.Wrap(err)
}
g.Debugf("Resumed stream for session %v", resume.SessionID)
g.logger.DebugContext(stream.Context(), "Resumed stream for session", "session_id", resume.SessionID)
go forwardEvents(eventStream)
defer closeStream(eventStream)
} else if complete := request.GetCompleteStream(); complete != nil {
Expand Down Expand Up @@ -408,7 +412,7 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
return trace.Wrap(err)
}
}
g.Debugf("Completed stream for session %v", sessionID)
g.logger.DebugContext(stream.Context(), "Completed stream for session", "session_id", sessionID)
return nil
} else if flushAndClose := request.GetFlushAndCloseStream(); flushAndClose != nil {
if eventStream == nil {
Expand All @@ -422,7 +426,7 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
}
event, err := apievents.FromOneOf(*oneof)
if err != nil {
g.WithError(err).Debug("Failed to decode event.")
g.logger.DebugContext(stream.Context(), "Failed to decode event", "error", err)
return trace.Wrap(err)
}
// Currently only api/client.auditStreamer calls with an event
Expand Down Expand Up @@ -451,8 +455,15 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
if err != nil {
switch {
case events.IsPermanentEmitError(err):
g.WithError(err).WithField("event", event).
Error("Failed to EmitAuditEvent due to a permanent error. Event wil be omitted.")
g.logger.ErrorContext(stream.Context(), "Failed to EmitAuditEvent due to a permanent error, event wil be omitted",
slog.Any("error", err),
slog.Group("event",
slog.String("type", event.GetType()),
slog.String("code", event.GetCode()),
slog.String("id", event.GetID()),
slog.Int64("index", event.GetIndex()),
),
)
continue
default:
return trace.Wrap(err)
Expand All @@ -465,15 +476,18 @@ func (g *GRPCServer) CreateAuditStream(stream authpb.AuthService_CreateAuditStre
if counter%logInterval == 0 {
if seconds > 0 {
kbytes := float64(processed) / 1000
g.Debugf("Processed %v events, tx rate kbytes %v/second.", counter, kbytes/float64(seconds))
g.logger.DebugContext(stream.Context(), "Processed events", "event_count", counter, "tx_rate", kbytes/float64(seconds))
}
}
diff := time.Since(start)
if diff > 100*time.Millisecond {
g.Warningf("RecordEvent(%v) took longer than 100ms: %v", event.GetType(), time.Since(event.GetTime()))
g.logger.WarnContext(stream.Context(), "RecordEvent took longer than 100ms",
"event_type", event.GetType(),
"duration", time.Since(event.GetTime()),
)
}
} else {
g.Errorf("Rejecting unsupported stream request: %v.", request)
g.logger.ErrorContext(stream.Context(), "Rejecting unsupported stream request", "request", request)
return trace.BadParameter("unsupported stream request")
}
}
Expand Down Expand Up @@ -569,7 +583,7 @@ func (g *GRPCServer) GenerateUserCerts(ctx context.Context, req *authpb.UserCert
return nil, trace.Wrap(err)
}
if err := validateUserCertsRequest(auth, req); err != nil {
g.Entry.Debugf("Validation of user certs request failed: %v", err)
g.logger.DebugContext(ctx, "Validation of user certs request failed", "error", err)
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -645,7 +659,7 @@ func (g *GRPCServer) generateUserSingleUseCerts(ctx context.Context, actx *grpcC
actx,
*req)
if err != nil {
g.Entry.Warningf("Failed to generate single-use cert: %v", err)
g.logger.WarnContext(ctx, "Failed to generate single-use cert", "error", err)
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -806,7 +820,10 @@ func (g *GRPCServer) GetInstances(filter *types.InstanceFilter, stream authpb.Au
for instances.Next() {
instance, ok := instances.Item().(*types.InstanceV1)
if !ok {
log.Warnf("Skipping unexpected instance type %T, expected %T.", instances.Item(), instance)
g.logger.WarnContext(stream.Context(), "Skipping unexpected instance type",
"instance_type", logutils.TypeAttr(instances.Item()),
"expected_instance_type", logutils.TypeAttr(instance),
)
continue
}
if err := stream.Send(instance); err != nil {
Expand Down Expand Up @@ -904,7 +921,10 @@ func (g *GRPCServer) GetCurrentUserRoles(_ *emptypb.Empty, stream authpb.AuthSer
for _, role := range roles {
v6, ok := role.(*types.RoleV6)
if !ok {
log.Warnf("expected type RoleV6, got %T for role %q", role, role.GetName())
g.logger.WarnContext(stream.Context(), "expected type RoleV6, got unexpected for role type",
"role_type", logutils.TypeAttr(role),
"role", role.GetName(),
)
return trace.Errorf("encountered unexpected role type")
}
if err := stream.Send(v6); err != nil {
Expand Down Expand Up @@ -2097,7 +2117,10 @@ func (g *GRPCServer) ListRoles(ctx context.Context, req *authpb.ListRolesRequest
for _, role := range rsp.Roles {
downgraded, err := maybeDowngradeRole(ctx, role)
if err != nil {
log.Warnf("Failed to downgrade role %q, this is a bug and may result in spurious access denied errors. err=%q", role.GetName(), err)
g.logger.WarnContext(ctx, "Failed to downgrade role, this is a bug and may result in spurious access denied errors",
"role", role.GetName(),
"error", err,
)
continue
}
downgradedRoles = append(downgradedRoles, downgraded)
Expand Down Expand Up @@ -2132,11 +2155,14 @@ func (g *GRPCServer) CreateRole(ctx context.Context, req *authpb.CreateRoleReque
return nil, trace.Wrap(err)
}

g.Debugf("%q role upserted", req.Role.GetName())
g.logger.DebugContext(ctx, "role upserted", "role_name", req.Role.GetName())

v6, ok := created.(*types.RoleV6)
if !ok {
log.Warnf("expected type RoleV6, got %T for role %q", created, created.GetName())
g.logger.WarnContext(ctx, "expected type RoleV6, got unexpected type",
"role_type", logutils.TypeAttr(created),
"role", created.GetName(),
)
return nil, trace.BadParameter("encountered unexpected role type")
}

Expand Down Expand Up @@ -2168,11 +2194,14 @@ func (g *GRPCServer) UpdateRole(ctx context.Context, req *authpb.UpdateRoleReque
return nil, trace.Wrap(err)
}

g.Debugf("%q role upserted", req.Role.GetName())
g.logger.DebugContext(ctx, "role upserted", "role", req.Role.GetName())

v6, ok := updated.(*types.RoleV6)
if !ok {
log.Warnf("expected type RoleV6, got %T for role %q", updated, updated.GetName())
g.logger.WarnContext(ctx, "expected type RoleV6, got unexpected type",
"role_type", logutils.TypeAttr(updated),
"role", updated.GetName(),
)
return nil, trace.BadParameter("encountered unexpected role type")
}

Expand Down Expand Up @@ -2204,11 +2233,14 @@ func (g *GRPCServer) UpsertRoleV2(ctx context.Context, req *authpb.UpsertRoleReq
return nil, trace.Wrap(err)
}

g.Debugf("%q role upserted", req.Role.GetName())
g.logger.DebugContext(ctx, "role upserted", "role", req.Role.GetName())

v6, ok := upserted.(*types.RoleV6)
if !ok {
log.Warnf("expected type RoleV6, got %T for role %q", upserted, upserted.GetName())
g.logger.WarnContext(ctx, "expected type RoleV6, got unexpected type",
"role_type", logutils.TypeAttr(upserted),
"role", upserted.GetName(),
)
return nil, trace.BadParameter("encountered unexpected role type")
}

Expand All @@ -2231,7 +2263,7 @@ func (g *GRPCServer) DeleteRole(ctx context.Context, req *authpb.DeleteRoleReque
return nil, trace.Wrap(err)
}

g.Debugf("%q role deleted", req.GetName())
g.logger.DebugContext(ctx, "role deleted", "role", req.GetName())

return &emptypb.Empty{}, nil
}
Expand Down Expand Up @@ -2907,7 +2939,10 @@ func (g *GRPCServer) GetServerInfos(_ *emptypb.Empty, stream authpb.AuthService_
for infos.Next() {
si, ok := infos.Item().(*types.ServerInfoV1)
if !ok {
log.Warnf("Skipping unexpected instance type %T, expected %T.", infos.Item(), si)
g.logger.WarnContext(stream.Context(), "expected type ServerInfoV1, got unexpected type",
"server_info_type", logutils.TypeAttr(infos.Item()),
"server_info_name", infos.Item().GetName(),
)
}
if err := stream.Send(si); err != nil {
infos.Done()
Expand Down Expand Up @@ -3836,7 +3871,7 @@ func (g *GRPCServer) UpsertWindowsDesktopService(ctx context.Context, service *t
// the closest thing we have to a public IP for the service.
clientAddr, err := authz.ClientSrcAddrFromContext(ctx)
if err != nil {
g.Logger.WithError(err).Warn("error getting client address from context")
g.logger.WarnContext(ctx, "error getting client address from context", "error", err)
return nil, status.Errorf(codes.FailedPrecondition, "client address not found in request context")
}
service.Spec.Addr = utils.ReplaceLocalhost(service.GetAddr(), clientAddr.String())
Expand Down Expand Up @@ -4223,7 +4258,7 @@ func (g *GRPCServer) CreateSessionTracker(ctx context.Context, req *authpb.Creat
}

if req.SessionTracker == nil {
g.Errorf("Missing SessionTracker in CreateSessionTrackerRequest. This can be caused by an outdated Teleport node running against your cluster.")
g.logger.ErrorContext(ctx, "Missing SessionTracker in CreateSessionTrackerRequest, this can be caused by an outdated Teleport node running against your cluster")
return nil, trace.BadParameter("missing SessionTracker from CreateSessionTrackerRequest")
}

Expand Down Expand Up @@ -5078,7 +5113,13 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
log.Debugf("gRPC(SERVER): keep alive %v count: %v.", cfg.KeepAlivePeriod, cfg.KeepAliveCount)

logger := slog.With(teleport.ComponentKey, teleport.Component(teleport.ComponentAuth, teleport.ComponentGRPC))

logger.DebugContext(context.Background(), "creating gRPC server",
"keep_alive_period", cfg.KeepAlivePeriod,
"keep_alive_count", cfg.KeepAliveCount,
)

// httplib.TLSCreds are explicitly used instead of credentials.NewTLS because the latter
// modifies the tls.Config.NextProtos which causes problems due to multiplexing on the auth
Expand Down Expand Up @@ -5242,20 +5283,18 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {

authServer := &GRPCServer{
APIConfig: cfg.APIConfig,
Entry: logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.Component(teleport.ComponentAuth, teleport.ComponentGRPC),
}),
server: server,
logger: logger,
server: server,
}

if en := os.Getenv("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT"); en != "" {
inflightLimit, err := strconv.ParseInt(en, 10, 64)
if err != nil {
log.Error("Failed to parse the TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT envvar, limit will not be enforced.")
logger.ErrorContext(context.Background(), "Failed to parse the TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT envvar, limit will not be enforced")
inflightLimit = -1
}
if inflightLimit == 0 {
log.Warn("TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT is set to 0, no CreateAuditStream RPCs will be allowed.")
logger.WarnContext(context.Background(), "TELEPORT_UNSTABLE_CREATEAUDITSTREAM_INFLIGHT_LIMIT is set to 0, no CreateAuditStream RPCs will be allowed")
}
metrics.RegisterPrometheusCollectors(
createAuditStreamAcceptedTotalMetric,
Expand Down

0 comments on commit 8acd795

Please sign in to comment.