Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Propagate request id on incoming and outgoing requests (#582)
Browse files Browse the repository at this point in the history
Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
EngHabu authored Jul 13, 2023
1 parent b93457f commit fb359bc
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"strings"
"time"

"google.golang.org/grpc/metadata"
"k8s.io/apimachinery/pkg/util/rand"

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"

Expand Down Expand Up @@ -77,7 +80,8 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) {

logger.Infof(ctx, "Registering default middleware with blanket auth validation")
pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor))
pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(
RequestIDInterceptor, auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor))

// Not yet implemented for streaming
var chainedUnaryInterceptors grpc.UnaryServerInterceptor
Expand Down Expand Up @@ -228,11 +232,50 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.
return nil, errors.Wrap(err, "error registering signal service")
}

mux.Handle("/", gwmux)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
ctx := GetOrGenerateRequestIDForRequest(r)
gwmux.ServeHTTP(w, r.WithContext(ctx))
})

return mux, nil
}

// RequestIDInterceptor is a server interceptor that sets the request id on the context for any incoming calls.
func RequestIDInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return handler(GetOrGenerateRequestIDForGRPC(ctx), req)
}

// GetOrGenerateRequestIDForGRPC returns a context with request id set from the context or from grpc metadata if it exists,
// otherwise it generates a new one.
func GetOrGenerateRequestIDForGRPC(ctx context.Context) context.Context {
if ctx.Value(contextutils.RequestIDKey) != nil {
return ctx
} else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 {
return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0])
} else {
return contextutils.WithRequestID(ctx, generateRequestID())
}
}

// GetOrGenerateRequestIDForRequest returns a context with request id set from the context or from metadata if it exists,
// otherwise it generates a new one.
func GetOrGenerateRequestIDForRequest(req *http.Request) context.Context {
ctx := req.Context()
if ctx.Value(contextutils.RequestIDKey) != nil {
return ctx
} else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 {
return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0])
} else if req.Header != nil && req.Header.Get(contextutils.RequestIDKey.String()) != "" {
return contextutils.WithRequestID(ctx, req.Header.Get(contextutils.RequestIDKey.String()))
} else {
return contextutils.WithRequestID(ctx, generateRequestID())
}
}

func generateRequestID() string {
return "a-" + rand.String(20)
}

func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig,
authCfg *authConfig.Config, storageConfig *storage.Config,
additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error {
Expand Down

0 comments on commit fb359bc

Please sign in to comment.