Skip to content

Commit

Permalink
Merge pull request #1402 from authzed/improve-mw-assertion-api
Browse files Browse the repository at this point in the history
disambiguate middleware dependency assertion in streaming APIs
  • Loading branch information
vroldanbet authored Jun 21, 2023
2 parents ac5f23a + 6ba179d commit bfc1b2c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 33 deletions.
6 changes: 2 additions & 4 deletions pkg/cmd/server/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,12 @@ func DefaultUnaryMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, e
NewUnaryMiddleware().
WithName(DefaultMiddlewareGRPCProm).
WithInterceptor(grpcprom.UnaryServerInterceptor).
EnsureNotExecuted(DefaultMiddlewareGRPCAuth).
Done(),

NewUnaryMiddleware().
WithName(DefaultMiddlewareGRPCAuth).
WithInterceptor(grpcauth.UnaryServerInterceptor(authFunc)).
EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm).
EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
Done(),

NewUnaryMiddleware().
Expand Down Expand Up @@ -229,13 +228,12 @@ func DefaultStreamingMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFun
NewStreamMiddleware().
WithName(DefaultMiddlewareGRPCProm).
WithInterceptor(grpcprom.StreamServerInterceptor).
EnsureAlreadyExecuted(DefaultMiddlewareGRPCAuth).
Done(),

NewStreamMiddleware().
WithName(DefaultMiddlewareGRPCAuth).
WithInterceptor(grpcauth.StreamServerInterceptor(authFunc)).
EnsureNotExecuted(DefaultMiddlewareGRPCProm).
EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports auth failures
Done(),

NewStreamMiddleware().
Expand Down
85 changes: 57 additions & 28 deletions pkg/cmd/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,15 +226,15 @@ type streamOrderAssertion struct {
}

func (o streamOrderAssertion) RecvMsg(m any) error {
if err := mustHaveExecuted(o.Context(), o.alreadyExecuted); err != nil {
if err := mustHaveExecuted(o.Context(), streamExecuted{}, o.alreadyExecuted); err != nil {
return err
}

if err := mustHaveNotExecuted(o.Context(), o.notExecuted); err != nil {
if err := mustHaveNotExecuted(o.Context(), streamExecuted{}, o.notExecuted); err != nil {
return err
}

mustMarkAsExecuted(o.Context(), o.name)
mustMarkAsExecuted(o.Context(), streamExecuted{}, o.name)
err := o.ServerStream.RecvMsg(m)
return err
}
Expand All @@ -248,11 +248,13 @@ func NewStreamMiddleware() *StreamOrderEnforcerBuilder {
}

type StreamOrderEnforcerBuilder struct {
name string
streamInterceptor grpc.StreamServerInterceptor
internal bool
executed string
notExecuted string
name string
streamInterceptor grpc.StreamServerInterceptor
internal bool
interceptorExecuted string
interceptorNotExecuted string
streamWrapperExecuted string
streamWrapperNotExecuted string
}

func (soeb *StreamOrderEnforcerBuilder) WithName(name string) *StreamOrderEnforcerBuilder {
Expand All @@ -270,13 +272,23 @@ func (soeb *StreamOrderEnforcerBuilder) WithInternal(internal bool) *StreamOrder
return soeb
}

func (soeb *StreamOrderEnforcerBuilder) EnsureAlreadyExecuted(name string) *StreamOrderEnforcerBuilder {
soeb.executed = name
func (soeb *StreamOrderEnforcerBuilder) EnsureWrapperAlreadyExecuted(name string) *StreamOrderEnforcerBuilder {
soeb.streamWrapperExecuted = name
return soeb
}

func (soeb *StreamOrderEnforcerBuilder) EnsureNotExecuted(name string) *StreamOrderEnforcerBuilder {
soeb.notExecuted = name
func (soeb *StreamOrderEnforcerBuilder) EnsureWrapperNotExecuted(name string) *StreamOrderEnforcerBuilder {
soeb.streamWrapperNotExecuted = name
return soeb
}

func (soeb *StreamOrderEnforcerBuilder) EnsureInterceptorAlreadyExecuted(name string) *StreamOrderEnforcerBuilder {
soeb.interceptorExecuted = name
return soeb
}

func (soeb *StreamOrderEnforcerBuilder) EnsureInterceptorNotExecuted(name string) *StreamOrderEnforcerBuilder {
soeb.interceptorNotExecuted = name
return soeb
}

Expand All @@ -294,15 +306,30 @@ func (soeb *StreamOrderEnforcerBuilder) Done() ReferenceableMiddleware[grpc.Stre
Internal: soeb.internal,
Middleware: func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wss := middleware.WrapServerStream(ss)
if wss.WrappedContext.Value(interceptorsExecuted) == nil {
if wss.WrappedContext.Value(streamExecuted{}) == nil {
handle := executedHandle{executed: make(map[string]struct{}, 0)}
wss.WrappedContext = context.WithValue(wss.WrappedContext, streamExecuted{}, &handle)
}
if wss.WrappedContext.Value(interceptorsExecuted{}) == nil {
handle := executedHandle{executed: make(map[string]struct{}, 0)}
wss.WrappedContext = context.WithValue(wss.WrappedContext, interceptorsExecuted, &handle)
wss.WrappedContext = context.WithValue(wss.WrappedContext, interceptorsExecuted{}, &handle)
}

if err := mustHaveExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.interceptorExecuted); err != nil {
return err
}

if err := mustHaveNotExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.interceptorNotExecuted); err != nil {
return err
}

mustMarkAsExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.name)

wrappedStream := streamOrderAssertion{
ServerStream: wss,
name: soeb.name,
alreadyExecuted: soeb.executed,
notExecuted: soeb.notExecuted,
alreadyExecuted: soeb.streamWrapperExecuted,
notExecuted: soeb.streamWrapperNotExecuted,
}
return soeb.streamInterceptor(srv, wrappedStream, info, handler)
},
Expand Down Expand Up @@ -359,31 +386,31 @@ func (soeb *UnaryOrderEnforcerBuilder) Done() ReferenceableMiddleware[grpc.Unary
Name: soeb.name,
Internal: soeb.internal,
Middleware: func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if ctx.Value(interceptorsExecuted) == nil {
if ctx.Value(interceptorsExecuted{}) == nil {
handle := executedHandle{executed: make(map[string]struct{}, 0)}
ctx = context.WithValue(ctx, interceptorsExecuted, &handle)
ctx = context.WithValue(ctx, interceptorsExecuted{}, &handle)
}

if err := mustHaveExecuted(ctx, soeb.alreadyExecuted); err != nil {
if err := mustHaveExecuted(ctx, interceptorsExecuted{}, soeb.alreadyExecuted); err != nil {
return nil, err
}

if err := mustHaveNotExecuted(ctx, soeb.notExecuted); err != nil {
if err := mustHaveNotExecuted(ctx, interceptorsExecuted{}, soeb.notExecuted); err != nil {
return nil, err
}

mustMarkAsExecuted(ctx, soeb.name)
mustMarkAsExecuted(ctx, interceptorsExecuted{}, soeb.name)
return soeb.interceptor(ctx, req, info, handler)
},
}
}

func mustHaveNotExecuted(ctx context.Context, notExecuted string) error {
func mustHaveNotExecuted(ctx context.Context, handleKey any, notExecuted string) error {
if notExecuted == "" {
return nil
}

val := ctx.Value(interceptorsExecuted)
val := ctx.Value(handleKey)
if val == nil {
return fmt.Errorf("interception order validation bookkeeping not present in context")
}
Expand All @@ -396,12 +423,12 @@ func mustHaveNotExecuted(ctx context.Context, notExecuted string) error {
return nil
}

func mustHaveExecuted(ctx context.Context, expectedExecuted string) error {
func mustHaveExecuted(ctx context.Context, handleKey any, expectedExecuted string) error {
if expectedExecuted == "" {
return nil
}

val := ctx.Value(interceptorsExecuted)
val := ctx.Value(handleKey)
if val == nil {
return spiceerrors.MustBugf("interception order validation bookkeeping not present in context")
}
Expand All @@ -414,8 +441,8 @@ func mustHaveExecuted(ctx context.Context, expectedExecuted string) error {
return fmt.Errorf("expected interceptor %s to be already executed", expectedExecuted)
}

func mustMarkAsExecuted(ctx context.Context, name string) {
val := ctx.Value(interceptorsExecuted)
func mustMarkAsExecuted(ctx context.Context, handleKey any, name string) {
val := ctx.Value(handleKey)
if val == nil {
panic("handle should exist")
} else {
Expand All @@ -428,4 +455,6 @@ type executedHandle struct {
executed map[string]struct{}
}

var interceptorsExecuted = struct{}{}
type interceptorsExecuted struct{}

type streamExecuted struct{}
2 changes: 1 addition & 1 deletion pkg/cmd/server/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ func TestIncorrectOrderAssertionFails(t *testing.T) {
NewStreamMiddleware().
WithName("test").
WithInterceptor(noopStreaming).
EnsureAlreadyExecuted("does-not-exist").
EnsureWrapperAlreadyExecuted("does-not-exist").
Done(),
},
},
Expand Down

0 comments on commit bfc1b2c

Please sign in to comment.