From 3420663c6688be098f3e4d703a64a42d1981a3d5 Mon Sep 17 00:00:00 2001 From: Moe Jangda Date: Thu, 3 Oct 2024 15:33:45 -0500 Subject: [PATCH] move verb export check to before req body validation --- backend/controller/controller.go | 33 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/backend/controller/controller.go b/backend/controller/controller.go index 4c22aefb76..de5782af06 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -1012,34 +1012,35 @@ func (s *Service) callWithRequest( return nil, err } - err := ingress.ValidateCallBody(req.Msg.Body, verb, sch) + callers, err := headers.GetCallers(req.Header()) if err != nil { - observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("invalid request: invalid call body")) + observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("failed to get callers")) return nil, err } + var currentCaller *schema.Ref + if len(callers) > 0 { + currentCaller = callers[len(callers)-1] + } + module := verbRef.Module - route, ok := sstate.routes[module] - if !ok { - observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("no routes for module")) - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("no routes for module %q", module)) + if currentCaller.Module != module && !verb.IsExported() { + observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("invalid request: verb not exported")) + return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("verb %q is not exported", verbRef)) } - client := s.clientsForEndpoint(route.Endpoint) - callers, err := headers.GetCallers(req.Header()) + err = ingress.ValidateCallBody(req.Msg.Body, verb, sch) if err != nil { - observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("failed to get callers")) + observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("invalid request: invalid call body")) return nil, err } - if !verb.IsExported() { - for _, caller := range callers { - if caller.Module != module { - observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("invalid request: verb not exported")) - return nil, connect.NewError(connect.CodePermissionDenied, fmt.Errorf("verb %q is not exported", verbRef)) - } - } + route, ok := sstate.routes[module] + if !ok { + observability.Calls.Request(ctx, req.Msg.Verb, start, optional.Some("no routes for module")) + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("no routes for module %q", module)) } + client := s.clientsForEndpoint(route.Endpoint) var requestKey model.RequestKey isNewRequestKey := false