diff --git a/pkg/server/service.go b/pkg/server/service.go index f3b27416f..4a7983087 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -9,6 +9,9 @@ import ( "strings" "time" + "google.golang.org/grpc/metadata" + "k8s.io/apimachinery/pkg/util/rand" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" @@ -77,7 +80,8 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) { logger.Infof(ctx, "Registering default middleware with blanket auth validation") - pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor)) + pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer( + RequestIDInterceptor, auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor)) // Not yet implemented for streaming var chainedUnaryInterceptors grpc.UnaryServerInterceptor @@ -228,11 +232,50 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig. return nil, errors.Wrap(err, "error registering signal service") } - mux.Handle("/", gwmux) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + ctx := GetOrGenerateRequestIDForRequest(r) + gwmux.ServeHTTP(w, r.WithContext(ctx)) + }) return mux, nil } +// RequestIDInterceptor is a server interceptor that sets the request id on the context for any incoming calls. +func RequestIDInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return handler(GetOrGenerateRequestIDForGRPC(ctx), req) +} + +// GetOrGenerateRequestIDForGRPC returns a context with request id set from the context or from grpc metadata if it exists, +// otherwise it generates a new one. +func GetOrGenerateRequestIDForGRPC(ctx context.Context) context.Context { + if ctx.Value(contextutils.RequestIDKey) != nil { + return ctx + } else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 { + return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0]) + } else { + return contextutils.WithRequestID(ctx, generateRequestID()) + } +} + +// GetOrGenerateRequestIDForRequest returns a context with request id set from the context or from metadata if it exists, +// otherwise it generates a new one. +func GetOrGenerateRequestIDForRequest(req *http.Request) context.Context { + ctx := req.Context() + if ctx.Value(contextutils.RequestIDKey) != nil { + return ctx + } else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 { + return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0]) + } else if req.Header != nil && req.Header.Get(contextutils.RequestIDKey.String()) != "" { + return contextutils.WithRequestID(ctx, req.Header.Get(contextutils.RequestIDKey.String())) + } else { + return contextutils.WithRequestID(ctx, generateRequestID()) + } +} + +func generateRequestID() string { + return "a-" + rand.String(20) +} + func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, authCfg *authConfig.Config, storageConfig *storage.Config, additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error {