diff --git a/bad_route.go b/bad_route.go index 472b2762..052dccdc 100644 --- a/bad_route.go +++ b/bad_route.go @@ -12,16 +12,24 @@ import ( // Twirp specification, mount this handler at the root of your API (so that it // handles any requests for invalid protobuf methods). func NewBadRouteHandler(opts ...HandlerOption) *Handler { + wrapped := Func(badRouteUnaryImpl) + if ic := ConfiguredHandlerInterceptor(opts...); ic != nil { + wrapped = ic.Wrap(wrapped) + } return NewHandler( - "", "", "", // protobuf method, service, package names - func() proto.Message { return &emptypb.Empty{} }, // unused req msg - func(ctx context.Context, _ proto.Message) (proto.Message, error) { - path := "???" - if md, ok := HandlerMeta(ctx); ok { - path = md.Spec.Path - } - return nil, Wrap(CodeNotFound, newBadRouteError(path)) + "", "", "", // protobuf package, service, method names + func(ctx context.Context, stream Stream) { + defer stream.CloseReceive() + _, err := wrapped(ctx, &emptypb.Empty{}) + _ = stream.CloseSend(err) }, - opts..., ) } + +func badRouteUnaryImpl(ctx context.Context, _ proto.Message) (proto.Message, error) { + path := "???" + if md, ok := HandlerMeta(ctx); ok { + path = md.Spec.Path + } + return nil, Wrap(CodeNotFound, newBadRouteError(path)) +} diff --git a/client.go b/client.go index 2cb84afa..661fcb3e 100644 --- a/client.go +++ b/client.go @@ -1,19 +1,10 @@ package rerpc import ( - "bytes" "context" - "errors" - "io" - "io/ioutil" + "fmt" "net/http" "net/url" - "strconv" - "time" - - "google.golang.org/protobuf/proto" - - statuspb "github.com/rerpc/rerpc/internal/status/v1" ) // Doer is the transport-level interface reRPC expects HTTP clients to @@ -23,8 +14,11 @@ type Doer interface { } type callCfg struct { + Package string + Service string + Method string EnableGzipRequest bool - MaxResponseBytes int + MaxResponseBytes int64 Interceptor Interceptor Hooks *Hooks } @@ -44,47 +38,62 @@ type CallOption interface { // To see an example of how Client is used in the generated code, see the // internal/ping/v1test package. type Client struct { - doer Doer - url string - methodFQN string - serviceFQN string - packageFQN string - newResponse func() proto.Message - opts []CallOption + doer Doer + baseURL string + pkg, service, method string + opts []CallOption } -// NewClient creates a Client. The supplied URL must be the full, -// method-specific URL, without trailing slashes. The supplied method, service, -// and package must be fully-qualified protobuf identifiers, and the -// newResponse constructor must be safe to call concurrently. Any options -// passed here apply to all calls made with this client. +// NewClient creates a Client. The supplied URL must be the root URL of the +// server's API, without a trailing slash (e.g., https://api.acme.com or +// https://acme.com/grpc). The supplied package, service, and method must be +// protobuf identifiers. Any options passed here apply to all calls made with +// this client. // -// For example, the URL https://api.acme.com/acme.foo.v1.FooService/Bar -// corresponds to method "acme.foo.v1.FooService.Bar", service -// "acme.foo.v1.FooService", and package "acme.foo.v1". In that case, the -// newResponse constructor would be: -// func() proto.Message { -// return &foopb.BarResponse{} -// } +// For example, to call the URL +// https://api.acme.com/acme.foo.v1.FooService/Bar, you'd pass the URL +// "https://api.acme.com", the package "acme.foo.v1", the service "FooService", +// and the method "Bar". // // Remember that NewClient is usually called from generated code - most users -// won't need to deal with long URLs or protobuf identifiers directly. -func NewClient(doer Doer, url, methodFQN, serviceFQN, packageFQN string, newResponse func() proto.Message, opts ...CallOption) *Client { +// won't need to deal with it directly. +// +// TODO: refactor this into a one-shot call. There's virtually no work that +// happens at the client level outside generated code. +func NewClient(doer Doer, baseURL, pkg, service, method string, opts ...CallOption) *Client { return &Client{ - doer: doer, - url: url, - methodFQN: methodFQN, - serviceFQN: serviceFQN, - packageFQN: packageFQN, - newResponse: newResponse, - opts: opts, + doer: doer, + baseURL: baseURL, + pkg: pkg, + service: service, + method: method, + opts: opts, } } -// Call the remote procedure. Any options passed apply only to the current -// call. -func (c *Client) Call(ctx context.Context, req proto.Message, opts ...CallOption) (proto.Message, error) { - var cfg callCfg +// Call creates a stream for the remote procedure. Any options passed apply +// only to the current call. +func (c *Client) Call(ctx context.Context, opts ...CallOption) Stream { + md, ok := CallMeta(ctx) + if !ok { + ctx = c.Context(ctx) + md, _ = CallMeta(ctx) + } + spec := md.Spec + methodURL := fmt.Sprintf("%s/%s.%s/%s", c.baseURL, spec.Package, spec.Service, spec.Method) + next := CallStreamFunc(func(ctx context.Context) Stream { + return newClientStream(ctx, c.doer, methodURL, spec.ReadMaxBytes, spec.RequestCompression == CompressionGzip) + }) + // TODO: apply interceptors + return next(ctx) +} + +func (c *Client) Context(ctx context.Context, opts ...CallOption) context.Context { + cfg := callCfg{ + Package: c.pkg, + Service: c.service, + Method: c.method, + } for _, opt := range c.opts { opt.applyToCall(&cfg) } @@ -92,24 +101,15 @@ func (c *Client) Call(ctx context.Context, req proto.Message, opts ...CallOption opt.applyToCall(&cfg) } - next := Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { - // Take care not to return a typed nil from this function. - res, err := c.call(ctx, req, &cfg) - if err != nil { - return nil, err - } - return res, nil - }) - if cfg.Interceptor != nil { - next = cfg.Interceptor.Wrap(next) - } spec := &Specification{ - Method: c.methodFQN, - Service: c.serviceFQN, - Package: c.packageFQN, + Package: cfg.Package, + Service: cfg.Service, + Method: cfg.Method, RequestCompression: CompressionGzip, + ReadMaxBytes: cfg.MaxResponseBytes, } - if url, err := url.Parse(c.url); err == nil { + methodURL := fmt.Sprintf("%s/%s.%s/%s", c.baseURL, spec.Package, spec.Service, spec.Method) + if url, err := url.Parse(methodURL); err == nil { spec.Path = url.Path } if !cfg.EnableGzipRequest { @@ -121,132 +121,5 @@ func (c *Client) Call(ctx context.Context, req proto.Message, opts ...CallOption reqHeader.Set("Grpc-Encoding", spec.RequestCompression) reqHeader.Set("Grpc-Accept-Encoding", acceptEncodingValue) // always advertise identity & gzip reqHeader.Set("Te", "trailers") - ctx = NewCallContext(ctx, *spec, reqHeader, make(http.Header)) - return next(ctx, req) -} - -func (c *Client) call(ctx context.Context, req proto.Message, cfg *callCfg) (proto.Message, *Error) { - md, hasMD := CallMeta(ctx) - if !hasMD { - return nil, errorf(CodeInternal, "no call metadata available on context") - } - - if deadline, ok := ctx.Deadline(); ok { - untilDeadline := time.Until(deadline) - if untilDeadline <= 0 { - return nil, errorf(CodeDeadlineExceeded, "no time to make RPC: timeout is %v", untilDeadline) - } - if enc, err := encodeTimeout(untilDeadline); err == nil { - // Tests verify that the error in encodeTimeout is unreachable, so we - // should be safe without observability for the error case. - md.req.raw.Set("Grpc-Timeout", enc) - } - } - - body := &bytes.Buffer{} - if err := marshalLPM(ctx, body, req, md.Spec.RequestCompression, 0 /* maxBytes */, cfg.Hooks); err != nil { - return nil, errorf(CodeInvalidArgument, "can't marshal request as protobuf: %w", err) - } - - request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, body) - if err != nil { - return nil, errorf(CodeInternal, "can't create HTTP request: %w", err) - } - request.Header = md.req.raw - - response, err := c.doer.Do(request) - if err != nil { - if errors.Is(err, context.Canceled) { - return nil, errorf(CodeCanceled, "context canceled") - } - if errors.Is(err, context.DeadlineExceeded) { - return nil, errorf(CodeDeadlineExceeded, "context deadline exceeded") - } - // Error message comes from our networking stack, so it's safe to expose. - return nil, wrap(CodeUnknown, err) - } - defer response.Body.Close() - defer io.Copy(ioutil.Discard, response.Body) - *md.res = NewImmutableHeader(response.Header) - - if response.StatusCode != http.StatusOK { - code := CodeUnknown - if c, ok := httpToGRPC[response.StatusCode]; ok { - code = c - } - return nil, errorf(code, "HTTP status %v", response.StatusCode) - } - compression := response.Header.Get("Grpc-Encoding") - if compression == "" { - compression = CompressionIdentity - } - switch compression { - case CompressionIdentity, CompressionGzip: - default: - // Per https://github.com/grpc/grpc/blob/master/doc/compression.md, we - // should return CodeInternal and specify acceptable compression(s) (in - // addition to setting the Grpc-Accept-Encoding header). - return nil, errorf( - CodeInternal, - "unknown compression %q: accepted grpc-encoding values are %v", - compression, - acceptEncodingValue, - ) - } - - // When there's no body, errors sent from the first-party gRPC servers will - // be in the headers. - if err := extractError(response.Header); err != nil { - return nil, err - } - - res := c.newResponse() - // Handling this error is a little complicated - read on. - unmarshalErr := unmarshalLPM(response.Body, res, compression, cfg.MaxResponseBytes) - // To ensure that we've read the trailers, read the body to completion. - io.Copy(io.Discard, response.Body) - serverErr := extractError(response.Trailer) - if serverErr != nil { - // Server sent us an error. In this case, we don't care if the - // length-prefixed message was corrupted and unmarshalErr is non-nil. - return nil, serverErr - } else if unmarshalErr != nil { - // Server thinks response was successful, so unmarshalErr is real. - return nil, errorf(CodeUnknown, "server returned invalid protobuf: %w", unmarshalErr) - } - // Server thinks response was successful and so do we, so we're done. - return res, nil -} - -func extractError(h http.Header) *Error { - codeHeader := h.Get("Grpc-Status") - codeIsSuccess := (codeHeader == "" || codeHeader == "0") - if codeIsSuccess { - return nil - } - - code, err := strconv.ParseUint(codeHeader, 10 /* base */, 32 /* bitsize */) - if err != nil { - return errorf(CodeUnknown, "gRPC protocol error: got invalid error code %q", codeHeader) - } - message := percentDecode(h.Get("Grpc-Message")) - ret := wrap(Code(code), errors.New(message)) - - detailsBinaryEncoded := h.Get("Grpc-Status-Details-Bin") - if len(detailsBinaryEncoded) > 0 { - detailsBinary, err := decodeBinaryHeader(detailsBinaryEncoded) - if err != nil { - return errorf(CodeUnknown, "server returned invalid grpc-error-details-bin trailer: %w", err) - } - var status statuspb.Status - if err := proto.Unmarshal(detailsBinary, &status); err != nil { - return errorf(CodeUnknown, "server returned invalid protobuf for error details: %w", err) - } - ret.details = status.Details - // Prefer the protobuf-encoded data to the headers (grpc-go does this too). - ret.code = Code(status.Code) - ret.err = errors.New(status.Message) - } - - return ret + return NewCallContext(ctx, *spec, reqHeader, make(http.Header)) } diff --git a/cmd/protoc-gen-go-rerpc/rerpc.go b/cmd/protoc-gen-go-rerpc/rerpc.go index 1c424f3b..f3920bcb 100644 --- a/cmd/protoc-gen-go-rerpc/rerpc.go +++ b/cmd/protoc-gen-go-rerpc/rerpc.go @@ -18,6 +18,11 @@ const ( stringsPackage = protogen.GoImportPath("strings") ) +var ( + contextContext = contextPackage.Ident("Context") + protoMessage = protoPackage.Ident("Message") +) + func deprecated(g *protogen.GeneratedFile) { comment(g, "// Deprecated: do not use.") } @@ -111,7 +116,7 @@ func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { deprecated(g) } - return method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + + return method.GoName + "(ctx " + g.QualifiedGoIdent(contextContext) + ", req *" + g.QualifiedGoIdent(method.Input.GoIdent) + ", opts ..." + g.QualifiedGoIdent(rerpcPackage.Ident("CallOption")) + ") " + "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)" @@ -123,6 +128,7 @@ func clientImplementation(g *protogen.GeneratedFile, service *protogen.Service, for _, method := range unaryMethods(service) { g.P(unexport(method.GoName), " ", rerpcPackage.Ident("Client")) } + g.P("options []", rerpcPackage.Ident("CallOption")) g.P("}") g.P() @@ -131,7 +137,7 @@ func clientImplementation(g *protogen.GeneratedFile, service *protogen.Service, " service. Call options passed here apply to all calls made with this client.") g.P("//") comment(g, "The URL supplied here should be the base URL for the gRPC server ", - "(e.g., https://api.acme.com or https://acme.com/api/grpc).") + "(e.g., https://api.acme.com or https://acme.com/grpc).") if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { g.P("//") deprecated(g) @@ -141,17 +147,16 @@ func clientImplementation(g *protogen.GeneratedFile, service *protogen.Service, g.P("baseURL = ", stringsPackage.Ident("TrimRight"), `(baseURL, "/")`) g.P("return &", unexport(name), "{") for _, method := range unaryMethods(service) { - path := fmt.Sprintf("%s/%s", service.Desc.FullName(), method.Desc.Name()) g.P(unexport(method.GoName), ": *", rerpcPackage.Ident("NewClient"), "(") g.P("doer,") - g.P(`baseURL + "/`, path, `", // complete URL to call method`) - g.P(`"`, method.Desc.FullName(), `", // fully-qualified protobuf method`) - g.P(`"`, service.Desc.FullName(), `", // fully-qualified protobuf service`) - g.P(`"`, service.Desc.ParentFile().Package(), `", // fully-qualified protobuf package`) - g.P("func() proto.Message { return &", method.Output.GoIdent, "{} }, // response constructor") + g.P("baseURL,") + g.P(`"`, service.Desc.ParentFile().Package(), `", // protobuf package`) + g.P(`"`, service.Desc.Name(), `", // protobuf service`) + g.P(`"`, method.Desc.Name(), `", // protobuf method`) g.P("opts...,") g.P("),") } + g.P("options: opts,") g.P("}") g.P("}") g.P() @@ -170,11 +175,38 @@ func clientMethod(g *protogen.GeneratedFile, method *protogen.Method) { deprecated(g) } g.P("func (c *", unexport(method.Parent.GoName), "ClientReRPC) ", clientSignature(g, method), "{") - g.P("res, err := c.", unexport(method.GoName), ".Call(ctx, req, opts...)") + g.P("wrapped := ", rerpcPackage.Ident("Func"), "(func(ctx ", contextContext, ", msg ", protoMessage, ") (", protoMessage, ", error) {") + g.P("stream := c.", unexport(method.GoName), ".Call(ctx, opts...)") + g.P("if err := stream.Send(req); err != nil {") + g.P("_ = stream.CloseSend(err)") + g.P("_ = stream.CloseReceive()") + g.P("return nil, err") + g.P("}") + g.P("if err := stream.CloseSend(nil); err != nil {") + g.P("_ = stream.CloseReceive()") + g.P("return nil, err") + g.P("}") + g.P("var res ", method.Output.GoIdent) + g.P("if err := stream.Receive(&res); err != nil {") + g.P("_ = stream.CloseReceive()") + g.P("return nil, err") + g.P("}") + g.P("return &res, stream.CloseReceive()") + g.P("})") + g.P("mergedOpts := append([]", rerpcPackage.Ident("CallOption"), "{}, c.options...)") + g.P("mergedOpts = append(mergedOpts, opts...)") + g.P("if ic := ", rerpcPackage.Ident("ConfiguredCallInterceptor"), "(mergedOpts...); ic != nil {") + g.P("wrapped = ic.Wrap(wrapped)") + g.P("}") + g.P("res, err := wrapped(c.", unexport(method.GoName), ".Context(ctx, opts...), req)") g.P("if err != nil {") g.P("return nil, err") g.P("}") - g.P("return res.(*", method.Output.GoIdent, "), nil") + g.P("typed, ok := res.(*", method.Output.GoIdent, ")") + g.P("if !ok {") + g.P("return nil, ", rerpcPackage.Ident("Errorf"), "(", rerpcPackage.Ident("CodeInternal"), `, "expected response to be `, method.Output.Desc.FullName(), `, got %v", res.ProtoReflect().Descriptor().FullName())`) + g.P("}") + g.P("return typed, nil") g.P("}") g.P() } @@ -206,13 +238,12 @@ func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { deprecated(g) } - return method.GoName + "(" + g.QualifiedGoIdent(contextPackage.Ident("Context")) + + return method.GoName + "(" + g.QualifiedGoIdent(contextContext) + ", *" + g.QualifiedGoIdent(method.Input.GoIdent) + ") " + "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)" } func serverConstructor(g *protogen.GeneratedFile, service *protogen.Service, name string) { - sname := service.Desc.FullName() comment(g, "New", service.GoName, "HandlerReRPC wraps the service implementation", " in an HTTP handler. It returns the handler and the path on which to mount it.") if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { @@ -222,37 +253,54 @@ func serverConstructor(g *protogen.GeneratedFile, service *protogen.Service, nam g.P("func New", service.GoName, "HandlerReRPC(svc ", name, ", opts ...", rerpcPackage.Ident("HandlerOption"), ") (string, *", httpPackage.Ident("ServeMux"), ") {") g.P("mux := ", httpPackage.Ident("NewServeMux"), "()") + g.P("ic := ", rerpcPackage.Ident("ConfiguredHandlerInterceptor"), "(opts...)") g.P() + lastHandlerName := "" for _, method := range unaryMethods(service) { - path := fmt.Sprintf("%s/%s", sname, method.Desc.Name()) hname := unexport(string(method.Desc.Name())) - g.P(hname, " := ", rerpcPackage.Ident("NewHandler"), "(") - g.P(`"`, method.Desc.FullName(), `", // fully-qualified protobuf method`) - g.P(`"`, service.Desc.FullName(), `", // fully-qualified protobuf service`) - g.P(`"`, service.Desc.ParentFile().Package(), `", // fully-qualified protobuf package`) - g.P("func() ", protoPackage.Ident("Message"), " { return &", method.Input.GoIdent, "{} }, // request msg constructor") - g.P(rerpcPackage.Ident("Func"), "(func(ctx ", contextPackage.Ident("Context"), - ", req ", protoPackage.Ident("Message"), ") (", - protoPackage.Ident("Message"), ", error) {") + wrapped := hname + "Func" + lastHandlerName = hname + g.P(wrapped, " := ", rerpcPackage.Ident("Func"), "(func(ctx ", contextContext, ", req ", protoMessage, ") (", protoMessage, ", error) {") g.P("typed, ok := req.(*", method.Input.GoIdent, ")") g.P("if !ok {") g.P("return nil, ", rerpcPackage.Ident("Errorf"), "(") g.P(rerpcPackage.Ident("CodeInternal"), ",") - g.P(`"error in generated code: expected req to be a *`, method.Input.GoIdent, `, got a %T",`) - g.P("req,") + g.P(`"can't call `, method.Desc.FullName(), ` with a %v",`) + g.P("req.ProtoReflect().Descriptor().FullName(),") g.P(")") g.P("}") g.P("return svc.", method.GoName, "(ctx, typed)") + g.P("})") + g.P("if ic != nil {") + g.P(wrapped, " = ic.Wrap(", wrapped, ")") + g.P("}") + g.P(hname, " := ", rerpcPackage.Ident("NewHandler"), "(") + g.P(`"`, service.Desc.ParentFile().Package(), `", // protobuf package`) + g.P(`"`, service.Desc.Name(), `", // protobuf service`) + g.P(`"`, method.Desc.Name(), `", // protobuf method`) + g.P(rerpcPackage.Ident("HandlerStreamFunc"), "(func(ctx ", contextContext, ", stream ", rerpcPackage.Ident("Stream"), ") {") + g.P("defer stream.CloseReceive()") + g.P("var req ", method.Input.GoIdent) + g.P("if err := stream.Receive(&req); err != nil {") + g.P(" _ = stream.CloseSend(err)") + g.P("return") + g.P("}") + g.P("res, err := ", wrapped, "(ctx, &req)") + g.P("if err != nil {") + g.P("_ = stream.CloseSend(err)") + g.P("return") + g.P("}") + g.P("_ = stream.CloseSend(stream.Send(res))") g.P("}),") g.P("opts...,") g.P(")") - g.P(`mux.Handle("/`, path, `", `, hname, ")") + g.P("mux.Handle(", hname, ".Path(), ", hname, ")") g.P() } comment(g, "Respond to unknown protobuf methods with gRPC and Twirp's 404 equivalents.") g.P(`mux.Handle("/", `, rerpcPackage.Ident("NewBadRouteHandler"), "(opts...))") g.P() - g.P(`return "/`, sname, `/", mux`) + g.P("return ", lastHandlerName, ".ServicePath(), mux") g.P("}") g.P() } diff --git a/handler.go b/handler.go index 27613a56..f216b8d4 100644 --- a/handler.go +++ b/handler.go @@ -3,18 +3,10 @@ package rerpc import ( "compress/gzip" "context" - "encoding/json" "fmt" "io" - "io/ioutil" "net/http" - "strconv" "strings" - - "google.golang.org/protobuf/proto" - - statuspb "github.com/rerpc/rerpc/internal/status/v1" - "github.com/rerpc/rerpc/internal/twirp" ) var ( @@ -33,10 +25,13 @@ var ( type handlerCfg struct { DisableGzipResponse bool DisableTwirp bool - MaxRequestBytes int + MaxRequestBytes int64 Registrar *Registrar Interceptor Interceptor Hooks *Hooks + Package string + Service string + Method string } // A HandlerOption configures a Handler. @@ -72,57 +67,36 @@ func ServeTwirp(enable bool) HandlerOption { // To see an example of how Handler is used in the generated code, see the // internal/pingpb/v0 package. type Handler struct { - methodFQN string - serviceFQN string - packageFQN string - newRequest func() proto.Message - config handlerCfg - - // Handlers must either unary or stream, but not both. - unary Func - stream func( - context.Context, - http.ResponseWriter, - *http.Request, - string, // request compression - string, // response compression - *Hooks, - ) + config handlerCfg + implementation HandlerStreamFunc } -// NewHandler constructs a Handler. The supplied method, service, and package -// must be fully-qualified protobuf identifiers, and the newRequest constructor -// must be safe to call concurrently. -// -// For example, a handler might have method "acme.foo.v1.FooService.Bar", -// service "acme.foo.v1.FooService", and package "acme.foo.v1". In that case, -// the newRequest constructor would be: -// func() proto.Message { -// return &foopb.BarRequest{} -// } +// NewHandler constructs a Handler. The supplied package, service, and method +// names must be protobuf identifiers. For example, a handler for the URL +// "/acme.foo.v1.FooService/Bar" would have package "acme.foo.v1", service +// "FooService", and method "Bar". // // Remember that NewHandler is usually called from generated code - most users // won't need to deal with protobuf identifiers directly. func NewHandler( - methodFQN, serviceFQN, packageFQN string, - newRequest func() proto.Message, - impl Func, + pkg, service, method string, + impl HandlerStreamFunc, opts ...HandlerOption, ) *Handler { - var cfg handlerCfg + cfg := handlerCfg{ + Package: pkg, + Service: service, + Method: method, + } for _, opt := range opts { opt.applyToHandler(&cfg) } if reg := cfg.Registrar; reg != nil { - reg.register(serviceFQN) + reg.register(cfg.Package, cfg.Service) } return &Handler{ - methodFQN: methodFQN, - serviceFQN: serviceFQN, - packageFQN: packageFQN, - newRequest: newRequest, - unary: impl, - config: cfg, + config: cfg, + implementation: impl, } } @@ -131,8 +105,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // To ensure that we can re-use connections, always consume and close the // request body. defer r.Body.Close() - defer io.Copy(ioutil.Discard, r.Body) + defer io.Copy(io.Discard, r.Body) + // TODO: verify HTTP/2 for bidirectional streaming + if false && r.ProtoMajor < 2 { + w.WriteHeader(http.StatusHTTPVersionNotSupported) + io.WriteString(w, "bidirectional streaming requires HTTP/2") + return + } if r.Method != http.MethodPost { // grpc-go returns a 500 here, but interoperability with non-gRPC HTTP // clients is better if we return a 405. @@ -142,13 +122,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } spec := &Specification{ - Method: h.methodFQN, - Service: h.serviceFQN, - Package: h.packageFQN, + Package: h.config.Package, + Service: h.config.Service, + Method: h.config.Method, Path: r.URL.Path, ContentType: r.Header.Get("Content-Type"), RequestCompression: CompressionIdentity, ResponseCompression: CompressionIdentity, + ReadMaxBytes: h.config.MaxRequestBytes, } if (spec.ContentType == TypeJSON || spec.ContentType == TypeProtoTwirp) && h.config.DisableTwirp { w.Header().Set("Accept-Post", acceptPostValueWithoutJSON) @@ -164,8 +145,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // We need to parse metadata before entering the interceptor stack, but we'd - // like any errors we encounter to be visible to interceptors for - // observability. We'll collect any such errors here and use them to + // like to report errors to the client in a format they understand (if + // possible). We'll collect any such errors here and use them to // short-circuit early later on. // // NB, future refactorings will need to take care to avoid typed nils here. @@ -234,8 +215,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - // We may write to the body in the implementation (e.g., reflection handler), so we should - // set headers here. + // We should write any remaining headers here, since: (a) the implementation + // may write to the body, thereby sending the headers, and (b) interceptors + // should be able to see this data. w.Header().Set("Content-Type", spec.ContentType) if spec.ContentType != TypeJSON && spec.ContentType != TypeProtoTwirp { w.Header().Set("Grpc-Accept-Encoding", acceptEncodingValue) @@ -246,190 +228,62 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add("Trailer", "Grpc-Status-Details-Bin") } - ctx := NewHandlerContext(r.Context(), *spec, r.Header, w.Header()) - var unary Func - if failed != nil { - unary = Func(func(context.Context, proto.Message) (proto.Message, error) { - return nil, failed - }) - } else if spec.ContentType == TypeJSON || spec.ContentType == TypeProtoTwirp { - unary = h.implementationTwirp(w, r, spec) - } else { - unary = h.implementationGRPC(w, r, spec) - } - res, err := h.wrap(unary)(ctx, h.newRequest()) - h.writeResult(r.Context(), w, spec, res, err) -} - -func (h *Handler) implementationTwirp(w http.ResponseWriter, r *http.Request, spec *Specification) Func { - return Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { - var body io.Reader = r.Body + // Unlike gRPC, Twirp manages compression using the standard HTTP mechanisms. + // Since they apply to the whole stream, it's easiest to handle it here. + var requestBody io.Reader = r.Body + if spec.ContentType == TypeJSON || spec.ContentType == TypeProtoTwirp { if spec.RequestCompression == CompressionGzip { - gr, err := gzip.NewReader(body) - if err != nil { - return nil, errorf(CodeInvalidArgument, "can't read gzipped body") + gr, err := gzip.NewReader(requestBody) + if err != nil && failed == nil { + failed = errorf(CodeInvalidArgument, "can't read gzipped body: %w", err) + } else if err == nil { + defer gr.Close() + requestBody = gr } - defer gr.Close() - body = gr } - if max := h.config.MaxRequestBytes; max > 0 { - body = &io.LimitedReader{ - R: body, - N: int64(max), - } - } - if spec.ContentType == TypeJSON { - if err := unmarshalJSON(body, req); err != nil { - return nil, wrap(CodeInvalidArgument, newMalformedError("can't unmarshal JSON body")) - } - } else { - if err := unmarshalTwirpProto(body, req); err != nil { - return nil, wrap(CodeInvalidArgument, newMalformedError("can't unmarshal Twirp protobuf body")) - } - } - return h.unary(ctx, req) - }) -} - -func (h *Handler) implementationGRPC(w http.ResponseWriter, r *http.Request, spec *Specification) Func { - return Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { - if s := h.stream; s != nil { - s(ctx, w, r, spec.RequestCompression, spec.ResponseCompression, h.config.Hooks) - return nil, nil - } - if err := unmarshalLPM(r.Body, req, spec.RequestCompression, h.config.MaxRequestBytes); err != nil { - return nil, errorf(CodeInvalidArgument, "can't unmarshal protobuf body") + // Checking Content-Encoding ensures that some other user-supplied + // middleware isn't already compressing the response. + if spec.ResponseCompression == CompressionGzip && w.Header().Get("Content-Encoding") == "" { + w.Header().Set("Content-Encoding", "gzip") + gw := getGzipWriter(w) + defer putGzipWriter(gw) + w = &gzipResponseWriter{ResponseWriter: w, gw: gw} } - return h.unary(ctx, req) - }) -} - -func (h *Handler) writeResult(ctx context.Context, w http.ResponseWriter, spec *Specification, res proto.Message, err error) { - if spec.ContentType == TypeJSON || spec.ContentType == TypeProtoTwirp { - h.writeResultTwirp(ctx, w, spec, res, err) - return } - h.writeResultGRPC(ctx, w, spec, res, err) -} -func (h *Handler) writeResultTwirp(ctx context.Context, w http.ResponseWriter, spec *Specification, res proto.Message, err error) { - // Even if the client requested gzip compression, check Content-Encoding to - // make sure some other HTTP middleware hasn't already swapped out the - // ResponseWriter. - if spec.ResponseCompression == CompressionGzip && w.Header().Get("Content-Encoding") == "" { - w.Header().Set("Content-Encoding", "gzip") - gw := getGzipWriter(w) - defer putGzipWriter(gw) - w = &gzipResponseWriter{ResponseWriter: w, gw: gw} - } - if err != nil { - // Twirp always writes errors as JSON. - writeErrorJSON(ctx, w, err, h.config.Hooks) - return - } - if spec.ContentType == TypeJSON { - marshalJSON(ctx, w, res, h.config.Hooks) - } else { - marshalTwirpProto(ctx, w, res, h.config.Hooks) + stream := newServerStream( + w, + &readCloser{Reader: requestBody, Closer: r.Body}, + spec.ContentType, + h.config.MaxRequestBytes, + spec.ResponseCompression == CompressionGzip, + ) + ctx := NewHandlerContext(r.Context(), *spec, r.Header, w.Header()) + // TODO: refactor interceptors and apply them here + if failed != nil { + _ = stream.CloseReceive() + _ = stream.CloseSend(failed) } + h.implementation(ctx, stream) } -func (h *Handler) writeResultGRPC(ctx context.Context, w http.ResponseWriter, spec *Specification, res proto.Message, err error) { - if err != nil { - writeErrorGRPC(ctx, w, err, h.config.Hooks) - return - } - if err := marshalLPM(ctx, w, res, spec.ResponseCompression, 0 /* maxBytes */, h.config.Hooks); err != nil { - // It's safe to write gRPC errors even after we've started writing the - // body. - writeErrorGRPC(ctx, w, errorf(CodeUnknown, "can't marshal protobuf response"), h.config.Hooks) - return - } - writeErrorGRPC(ctx, w, nil, h.config.Hooks) +// Path returns the URL pattern to use when registering this handler. It's used +// by the generated code. +func (h *Handler) Path() string { + return fmt.Sprintf("/%s.%s/%s", h.config.Package, h.config.Service, h.config.Method) } -func (h *Handler) wrap(next Func) Func { - if h.config.Interceptor != nil { - return h.config.Interceptor.Wrap(next) - } - return next +// ServicePath returns the URL pattern for the protobuf service. It's used by +// the generated code. +func (h *Handler) ServicePath() string { + return fmt.Sprintf("/%s.%s/", h.config.Package, h.config.Service) } func splitOnCommasAndSpaces(c rune) bool { return c == ',' || c == ' ' } -func writeErrorJSON(ctx context.Context, w http.ResponseWriter, err error, hooks *Hooks) { - // Even if the caller sends TypeProtoTwirp, we respond with TypeJSON on errors. - w.Header().Set("Content-Type", TypeJSON) - s := newTwirpStatus(err) - bs, merr := json.Marshal(s) - if merr != nil { - hooks.onMarshalError(ctx, merr) - w.WriteHeader(http.StatusInternalServerError) - // codes don't need to be escaped in JSON, so this is okay - const tmpl = `{"code": "%s", "msg": "error marshaling error with code %s"}` - if _, nerr := fmt.Fprintf(w, tmpl, CodeInternal.twirp(), s.Code); nerr != nil { - hooks.onNetworkError(ctx, nerr) - } - return - } - w.WriteHeader(CodeOf(err).http()) - _, err = w.Write(bs) - if err != nil { - hooks.onNetworkError(ctx, err) - } -} - -func writeErrorGRPC(ctx context.Context, w http.ResponseWriter, err error, hooks *Hooks) { - if err == nil { - w.Header().Set("Grpc-Status", strconv.Itoa(int(CodeOK))) - w.Header().Set("Grpc-Message", "") - w.Header().Set("Grpc-Status-Details-Bin", "") - return - } - // gRPC errors are successes at the HTTP level and net/http automatically - // sends a 200 if we don't set a status code. Leaving the HTTP status - // implicit lets us use this function when we hit an error partway through - // writing the body. - s := statusFromError(err) - code := strconv.Itoa(int(s.Code)) - // If we ever need to send more trailers, make sure to declare them in the headers - // above. - if bin, err := proto.Marshal(s); err != nil { - w.Header().Set("Grpc-Status", strconv.Itoa(int(CodeInternal))) - w.Header().Set("Grpc-Message", percentEncode("error marshaling protobuf status with code "+code)) - hooks.onMarshalError(ctx, err) - } else { - w.Header().Set("Grpc-Status", code) - w.Header().Set("Grpc-Message", percentEncode(s.Message)) - w.Header().Set("Grpc-Status-Details-Bin", encodeBinaryHeader(bin)) - } -} - -func statusFromError(err error) *statuspb.Status { - s := &statuspb.Status{ - Code: int32(CodeUnknown), - Message: err.Error(), - } - if re, ok := AsError(err); ok { - s.Code = int32(re.Code()) - s.Details = re.Details() - if e := re.Unwrap(); e != nil { - s.Message = e.Error() // don't repeat code - } - } - return s -} - -func newTwirpStatus(err error) *twirp.Status { - gs := statusFromError(err) - s := &twirp.Status{ - Code: Code(gs.Code).twirp(), - Message: gs.Message, - } - if te, ok := asTwirpError(err); ok { - s.Code = te.TwirpCode() - } - return s +type readCloser struct { + io.Reader + io.Closer } diff --git a/header.go b/header.go index c5703f86..6efd7237 100644 --- a/header.go +++ b/header.go @@ -124,7 +124,6 @@ func percentDecodeSlow(encoded string, offset int) string { } parsed, err := strconv.ParseUint(encoded[i+1:i+3], 16 /* hex */, 8 /* bitsize */) if err != nil { - fmt.Println(err) out.WriteRune(utf8.RuneError) } else { out.WriteByte(byte(parsed)) diff --git a/health.go b/health.go index 5ba262a7..260b35e5 100644 --- a/health.go +++ b/health.go @@ -60,53 +60,69 @@ func NewChecker(reg *Registrar) func(context.Context, string) (HealthStatus, err // https://github.com/grpc/grpc/blob/master/doc/health-checking.md // https://github.com/grpc/grpc/blob/master/src/proto/grpc/health/v1/health.proto func NewHealthHandler( - check func(context.Context, string) (HealthStatus, error), + checker func(context.Context, string) (HealthStatus, error), opts ...HandlerOption, ) (string, *http.ServeMux) { - const packageFQN = "grpc.health.v1" - const serviceFQN = packageFQN + ".Health" - const checkFQN = serviceFQN + ".Check" - const watchFQN = serviceFQN + ".Watch" - const servicePath = "/" + serviceFQN + "/" - const checkPath = servicePath + "Check" - const watchPath = servicePath + "Watch" - mux := http.NewServeMux() - checkHandler := NewHandler( - checkFQN, serviceFQN, packageFQN, - func() proto.Message { return &healthpb.HealthCheckRequest{} }, - func(ctx context.Context, req proto.Message) (proto.Message, error) { - typed, ok := req.(*healthpb.HealthCheckRequest) - if !ok { - return nil, errorf( - CodeInternal, - "can't call %s/Check with a %v", - serviceFQN, - req.ProtoReflect().Descriptor().FullName(), - ) + + checkImplementation := Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { + typed, ok := req.(*healthpb.HealthCheckRequest) + if !ok { + return nil, errorf( + CodeInternal, + "can't call grpc.health.v1.Health.Check with a %v", + req.ProtoReflect().Descriptor().FullName(), + ) + } + status, err := checker(ctx, typed.Service) + if err != nil { + return nil, err + } + return &healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_ServingStatus(status), + }, nil + }) + + watchImplementation := HandlerStreamFunc(func(_ context.Context, stream Stream) { + defer stream.CloseReceive() + _ = stream.CloseSend(errorf( + CodeUnimplemented, + "reRPC doesn't support watching health state", + )) + }) + + if ic := ConfiguredHandlerInterceptor(opts...); ic != nil { + checkImplementation = ic.Wrap(checkImplementation) + // TODO: apply stream interceptor + } + + check := NewHandler( + "grpc.health.v1", "Health", "Check", + HandlerStreamFunc(func(ctx context.Context, stream Stream) { + defer stream.CloseReceive() + var req healthpb.HealthCheckRequest + if err := stream.Receive(&req); err != nil { + _ = stream.CloseSend(err) + return } - status, err := check(ctx, typed.Service) + res, err := checkImplementation(ctx, &req) if err != nil { - return nil, err + _ = stream.CloseSend(err) + return } - return &healthpb.HealthCheckResponse{ - Status: healthpb.HealthCheckResponse_ServingStatus(status), - }, nil - }, + _ = stream.CloseSend(stream.Send(res)) + }), opts..., ) - mux.Handle(checkPath, checkHandler) + mux.Handle(check.Path(), check) watch := NewHandler( - watchFQN, serviceFQN, packageFQN, - func() proto.Message { return &healthpb.HealthCheckRequest{} }, - func(ctx context.Context, req proto.Message) (proto.Message, error) { - return nil, errorf(CodeUnimplemented, "reRPC doesn't support watching health state") - }, + "grpc.health.v1", "Health", "Watch", + watchImplementation, opts..., ) - mux.Handle(watchPath, watch) - mux.Handle("/", NewBadRouteHandler(opts...)) + mux.Handle(watch.Path(), watch) - return servicePath, mux + mux.Handle("/", NewBadRouteHandler(opts...)) + return watch.ServicePath(), mux } diff --git a/interceptor.go b/interceptor.go index 66515a66..e746efee 100644 --- a/interceptor.go +++ b/interceptor.go @@ -27,6 +27,26 @@ type Interceptor interface { Wrap(Func) Func } +// ConfiguredCallInterceptor returns the Interceptor configured by a collection +// of call options (if any). It's used in generated code. +func ConfiguredCallInterceptor(opts ...CallOption) Interceptor { + var cfg callCfg + for _, o := range opts { + o.applyToCall(&cfg) + } + return cfg.Interceptor +} + +// ConfiguredHandlerInterceptor returns the Interceptor configured by a collection +// of handler options (if any). It's used in generated code. +func ConfiguredHandlerInterceptor(opts ...HandlerOption) Interceptor { + var cfg handlerCfg + for _, o := range opts { + o.applyToHandler(&cfg) + } + return cfg.Interceptor +} + // ShortCircuit builds an interceptor that doesn't call the wrapped Func. // Instead, it returns the supplied Error immediately. // diff --git a/interceptor_example_test.go b/interceptor_example_test.go index c6ed74c6..af4b2835 100644 --- a/interceptor_example_test.go +++ b/interceptor_example_test.go @@ -29,7 +29,7 @@ func ExampleCallMetadata() { client.Ping(context.Background(), &pingpb.PingRequest{}) // Output: - // calling internal.ping.v1test.PingService.Ping + // calling Ping } func ExampleChain() { diff --git a/internal/crosstest/cross_test.go b/internal/crosstest/cross_test.go index 65bce315..0b02ea19 100644 --- a/internal/crosstest/cross_test.go +++ b/internal/crosstest/cross_test.go @@ -138,7 +138,7 @@ func testWithReRPCClient(t *testing.T, client crosspb.CrossServiceClientReRPC) { _, err := client.Ping(ctx, req) rerr := assertErrorReRPC(t, err, "deadline exceeded error") assert.Equal(t, rerr.Code(), rerpc.CodeDeadlineExceeded, "error code") - assert.Equal(t, rerr.Error(), "DeadlineExceeded: context deadline exceeded", "error message") + assert.ErrorIs(t, rerr, context.DeadlineExceeded, "error unwraps to context.DeadlineExceeded") }) } diff --git a/internal/crosstest/v1test/cross_rerpc.pb.go b/internal/crosstest/v1test/cross_rerpc.pb.go index bf5db645..4194090a 100644 --- a/internal/crosstest/v1test/cross_rerpc.pb.go +++ b/internal/crosstest/v1test/cross_rerpc.pb.go @@ -30,8 +30,9 @@ type CrossServiceClientReRPC interface { } type crossServiceClientReRPC struct { - ping rerpc.Client - fail rerpc.Client + ping rerpc.Client + fail rerpc.Client + options []rerpc.CallOption } // NewCrossServiceClientReRPC constructs a client for the @@ -39,49 +40,102 @@ type crossServiceClientReRPC struct { // apply to all calls made with this client. // // The URL supplied here should be the base URL for the gRPC server (e.g., -// https://api.acme.com or https://acme.com/api/grpc). +// https://api.acme.com or https://acme.com/grpc). func NewCrossServiceClientReRPC(baseURL string, doer rerpc.Doer, opts ...rerpc.CallOption) CrossServiceClientReRPC { baseURL = strings.TrimRight(baseURL, "/") return &crossServiceClientReRPC{ ping: *rerpc.NewClient( doer, - baseURL+"/internal.crosstest.v1test.CrossService/Ping", // complete URL to call method - "internal.crosstest.v1test.CrossService.Ping", // fully-qualified protobuf method - "internal.crosstest.v1test.CrossService", // fully-qualified protobuf service - "internal.crosstest.v1test", // fully-qualified protobuf package - func() proto.Message { return &PingResponse{} }, // response constructor + baseURL, + "internal.crosstest.v1test", // protobuf package + "CrossService", // protobuf service + "Ping", // protobuf method opts..., ), fail: *rerpc.NewClient( doer, - baseURL+"/internal.crosstest.v1test.CrossService/Fail", // complete URL to call method - "internal.crosstest.v1test.CrossService.Fail", // fully-qualified protobuf method - "internal.crosstest.v1test.CrossService", // fully-qualified protobuf service - "internal.crosstest.v1test", // fully-qualified protobuf package - func() proto.Message { return &FailResponse{} }, // response constructor + baseURL, + "internal.crosstest.v1test", // protobuf package + "CrossService", // protobuf service + "Fail", // protobuf method opts..., ), + options: opts, } } // Ping calls internal.crosstest.v1test.CrossService.Ping. Call options passed // here apply only to this call. func (c *crossServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, opts ...rerpc.CallOption) (*PingResponse, error) { - res, err := c.ping.Call(ctx, req, opts...) + wrapped := rerpc.Func(func(ctx context.Context, msg proto.Message) (proto.Message, error) { + stream := c.ping.Call(ctx, opts...) + if err := stream.Send(req); err != nil { + _ = stream.CloseSend(err) + _ = stream.CloseReceive() + return nil, err + } + if err := stream.CloseSend(nil); err != nil { + _ = stream.CloseReceive() + return nil, err + } + var res PingResponse + if err := stream.Receive(&res); err != nil { + _ = stream.CloseReceive() + return nil, err + } + return &res, stream.CloseReceive() + }) + mergedOpts := append([]rerpc.CallOption{}, c.options...) + mergedOpts = append(mergedOpts, opts...) + if ic := rerpc.ConfiguredCallInterceptor(mergedOpts...); ic != nil { + wrapped = ic.Wrap(wrapped) + } + res, err := wrapped(c.ping.Context(ctx, opts...), req) if err != nil { return nil, err } - return res.(*PingResponse), nil + typed, ok := res.(*PingResponse) + if !ok { + return nil, rerpc.Errorf(rerpc.CodeInternal, "expected response to be internal.crosstest.v1test.PingResponse, got %v", res.ProtoReflect().Descriptor().FullName()) + } + return typed, nil } // Fail calls internal.crosstest.v1test.CrossService.Fail. Call options passed // here apply only to this call. func (c *crossServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, opts ...rerpc.CallOption) (*FailResponse, error) { - res, err := c.fail.Call(ctx, req, opts...) + wrapped := rerpc.Func(func(ctx context.Context, msg proto.Message) (proto.Message, error) { + stream := c.fail.Call(ctx, opts...) + if err := stream.Send(req); err != nil { + _ = stream.CloseSend(err) + _ = stream.CloseReceive() + return nil, err + } + if err := stream.CloseSend(nil); err != nil { + _ = stream.CloseReceive() + return nil, err + } + var res FailResponse + if err := stream.Receive(&res); err != nil { + _ = stream.CloseReceive() + return nil, err + } + return &res, stream.CloseReceive() + }) + mergedOpts := append([]rerpc.CallOption{}, c.options...) + mergedOpts = append(mergedOpts, opts...) + if ic := rerpc.ConfiguredCallInterceptor(mergedOpts...); ic != nil { + wrapped = ic.Wrap(wrapped) + } + res, err := wrapped(c.fail.Context(ctx, opts...), req) if err != nil { return nil, err } - return res.(*FailResponse), nil + typed, ok := res.(*FailResponse) + if !ok { + return nil, rerpc.Errorf(rerpc.CodeInternal, "expected response to be internal.crosstest.v1test.FailResponse, got %v", res.ProtoReflect().Descriptor().FullName()) + } + return typed, nil } // CrossServiceReRPC is a server for the internal.crosstest.v1test.CrossService @@ -102,51 +156,84 @@ type CrossServiceReRPC interface { // handler. It returns the handler and the path on which to mount it. func NewCrossServiceHandlerReRPC(svc CrossServiceReRPC, opts ...rerpc.HandlerOption) (string, *http.ServeMux) { mux := http.NewServeMux() - + ic := rerpc.ConfiguredHandlerInterceptor(opts...) + + pingFunc := rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { + typed, ok := req.(*PingRequest) + if !ok { + return nil, rerpc.Errorf( + rerpc.CodeInternal, + "can't call internal.crosstest.v1test.CrossService.Ping with a %v", + req.ProtoReflect().Descriptor().FullName(), + ) + } + return svc.Ping(ctx, typed) + }) + if ic != nil { + pingFunc = ic.Wrap(pingFunc) + } ping := rerpc.NewHandler( - "internal.crosstest.v1test.CrossService.Ping", // fully-qualified protobuf method - "internal.crosstest.v1test.CrossService", // fully-qualified protobuf service - "internal.crosstest.v1test", // fully-qualified protobuf package - func() proto.Message { return &PingRequest{} }, // request msg constructor - rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { - typed, ok := req.(*PingRequest) - if !ok { - return nil, rerpc.Errorf( - rerpc.CodeInternal, - "error in generated code: expected req to be a *PingRequest, got a %T", - req, - ) + "internal.crosstest.v1test", // protobuf package + "CrossService", // protobuf service + "Ping", // protobuf method + rerpc.HandlerStreamFunc(func(ctx context.Context, stream rerpc.Stream) { + defer stream.CloseReceive() + var req PingRequest + if err := stream.Receive(&req); err != nil { + _ = stream.CloseSend(err) + return + } + res, err := pingFunc(ctx, &req) + if err != nil { + _ = stream.CloseSend(err) + return } - return svc.Ping(ctx, typed) + _ = stream.CloseSend(stream.Send(res)) }), opts..., ) - mux.Handle("/internal.crosstest.v1test.CrossService/Ping", ping) - + mux.Handle(ping.Path(), ping) + + failFunc := rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { + typed, ok := req.(*FailRequest) + if !ok { + return nil, rerpc.Errorf( + rerpc.CodeInternal, + "can't call internal.crosstest.v1test.CrossService.Fail with a %v", + req.ProtoReflect().Descriptor().FullName(), + ) + } + return svc.Fail(ctx, typed) + }) + if ic != nil { + failFunc = ic.Wrap(failFunc) + } fail := rerpc.NewHandler( - "internal.crosstest.v1test.CrossService.Fail", // fully-qualified protobuf method - "internal.crosstest.v1test.CrossService", // fully-qualified protobuf service - "internal.crosstest.v1test", // fully-qualified protobuf package - func() proto.Message { return &FailRequest{} }, // request msg constructor - rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { - typed, ok := req.(*FailRequest) - if !ok { - return nil, rerpc.Errorf( - rerpc.CodeInternal, - "error in generated code: expected req to be a *FailRequest, got a %T", - req, - ) + "internal.crosstest.v1test", // protobuf package + "CrossService", // protobuf service + "Fail", // protobuf method + rerpc.HandlerStreamFunc(func(ctx context.Context, stream rerpc.Stream) { + defer stream.CloseReceive() + var req FailRequest + if err := stream.Receive(&req); err != nil { + _ = stream.CloseSend(err) + return + } + res, err := failFunc(ctx, &req) + if err != nil { + _ = stream.CloseSend(err) + return } - return svc.Fail(ctx, typed) + _ = stream.CloseSend(stream.Send(res)) }), opts..., ) - mux.Handle("/internal.crosstest.v1test.CrossService/Fail", fail) + mux.Handle(fail.Path(), fail) // Respond to unknown protobuf methods with gRPC and Twirp's 404 equivalents. mux.Handle("/", rerpc.NewBadRouteHandler(opts...)) - return "/internal.crosstest.v1test.CrossService/", mux + return fail.ServicePath(), mux } var _ CrossServiceReRPC = (*UnimplementedCrossServiceReRPC)(nil) // verify interface implementation diff --git a/internal/ping/v1test/ping_rerpc.pb.go b/internal/ping/v1test/ping_rerpc.pb.go index 2a4820a7..f2d831aa 100644 --- a/internal/ping/v1test/ping_rerpc.pb.go +++ b/internal/ping/v1test/ping_rerpc.pb.go @@ -8,11 +8,10 @@ package pingpb import ( context "context" - http "net/http" - strings "strings" - rerpc "github.com/rerpc/rerpc" proto "google.golang.org/protobuf/proto" + http "net/http" + strings "strings" ) // This is a compile-time assertion to ensure that this generated file and the @@ -31,8 +30,9 @@ type PingServiceClientReRPC interface { } type pingServiceClientReRPC struct { - ping rerpc.Client - fail rerpc.Client + ping rerpc.Client + fail rerpc.Client + options []rerpc.CallOption } // NewPingServiceClientReRPC constructs a client for the @@ -40,49 +40,102 @@ type pingServiceClientReRPC struct { // all calls made with this client. // // The URL supplied here should be the base URL for the gRPC server (e.g., -// https://api.acme.com or https://acme.com/api/grpc). +// https://api.acme.com or https://acme.com/grpc). func NewPingServiceClientReRPC(baseURL string, doer rerpc.Doer, opts ...rerpc.CallOption) PingServiceClientReRPC { baseURL = strings.TrimRight(baseURL, "/") return &pingServiceClientReRPC{ ping: *rerpc.NewClient( doer, - baseURL+"/internal.ping.v1test.PingService/Ping", // complete URL to call method - "internal.ping.v1test.PingService.Ping", // fully-qualified protobuf method - "internal.ping.v1test.PingService", // fully-qualified protobuf service - "internal.ping.v1test", // fully-qualified protobuf package - func() proto.Message { return &PingResponse{} }, // response constructor + baseURL, + "internal.ping.v1test", // protobuf package + "PingService", // protobuf service + "Ping", // protobuf method opts..., ), fail: *rerpc.NewClient( doer, - baseURL+"/internal.ping.v1test.PingService/Fail", // complete URL to call method - "internal.ping.v1test.PingService.Fail", // fully-qualified protobuf method - "internal.ping.v1test.PingService", // fully-qualified protobuf service - "internal.ping.v1test", // fully-qualified protobuf package - func() proto.Message { return &FailResponse{} }, // response constructor + baseURL, + "internal.ping.v1test", // protobuf package + "PingService", // protobuf service + "Fail", // protobuf method opts..., ), + options: opts, } } // Ping calls internal.ping.v1test.PingService.Ping. Call options passed here // apply only to this call. func (c *pingServiceClientReRPC) Ping(ctx context.Context, req *PingRequest, opts ...rerpc.CallOption) (*PingResponse, error) { - res, err := c.ping.Call(ctx, req, opts...) + wrapped := rerpc.Func(func(ctx context.Context, msg proto.Message) (proto.Message, error) { + stream := c.ping.Call(ctx, opts...) + if err := stream.Send(req); err != nil { + _ = stream.CloseSend(err) + _ = stream.CloseReceive() + return nil, err + } + if err := stream.CloseSend(nil); err != nil { + _ = stream.CloseReceive() + return nil, err + } + var res PingResponse + if err := stream.Receive(&res); err != nil { + _ = stream.CloseReceive() + return nil, err + } + return &res, stream.CloseReceive() + }) + mergedOpts := append([]rerpc.CallOption{}, c.options...) + mergedOpts = append(mergedOpts, opts...) + if ic := rerpc.ConfiguredCallInterceptor(mergedOpts...); ic != nil { + wrapped = ic.Wrap(wrapped) + } + res, err := wrapped(c.ping.Context(ctx, opts...), req) if err != nil { return nil, err } - return res.(*PingResponse), nil + typed, ok := res.(*PingResponse) + if !ok { + return nil, rerpc.Errorf(rerpc.CodeInternal, "expected response to be internal.ping.v1test.PingResponse, got %v", res.ProtoReflect().Descriptor().FullName()) + } + return typed, nil } // Fail calls internal.ping.v1test.PingService.Fail. Call options passed here // apply only to this call. func (c *pingServiceClientReRPC) Fail(ctx context.Context, req *FailRequest, opts ...rerpc.CallOption) (*FailResponse, error) { - res, err := c.fail.Call(ctx, req, opts...) + wrapped := rerpc.Func(func(ctx context.Context, msg proto.Message) (proto.Message, error) { + stream := c.fail.Call(ctx, opts...) + if err := stream.Send(req); err != nil { + _ = stream.CloseSend(err) + _ = stream.CloseReceive() + return nil, err + } + if err := stream.CloseSend(nil); err != nil { + _ = stream.CloseReceive() + return nil, err + } + var res FailResponse + if err := stream.Receive(&res); err != nil { + _ = stream.CloseReceive() + return nil, err + } + return &res, stream.CloseReceive() + }) + mergedOpts := append([]rerpc.CallOption{}, c.options...) + mergedOpts = append(mergedOpts, opts...) + if ic := rerpc.ConfiguredCallInterceptor(mergedOpts...); ic != nil { + wrapped = ic.Wrap(wrapped) + } + res, err := wrapped(c.fail.Context(ctx, opts...), req) if err != nil { return nil, err } - return res.(*FailResponse), nil + typed, ok := res.(*FailResponse) + if !ok { + return nil, rerpc.Errorf(rerpc.CodeInternal, "expected response to be internal.ping.v1test.FailResponse, got %v", res.ProtoReflect().Descriptor().FullName()) + } + return typed, nil } // PingServiceReRPC is a server for the internal.ping.v1test.PingService @@ -103,51 +156,84 @@ type PingServiceReRPC interface { // handler. It returns the handler and the path on which to mount it. func NewPingServiceHandlerReRPC(svc PingServiceReRPC, opts ...rerpc.HandlerOption) (string, *http.ServeMux) { mux := http.NewServeMux() - + ic := rerpc.ConfiguredHandlerInterceptor(opts...) + + pingFunc := rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { + typed, ok := req.(*PingRequest) + if !ok { + return nil, rerpc.Errorf( + rerpc.CodeInternal, + "can't call internal.ping.v1test.PingService.Ping with a %v", + req.ProtoReflect().Descriptor().FullName(), + ) + } + return svc.Ping(ctx, typed) + }) + if ic != nil { + pingFunc = ic.Wrap(pingFunc) + } ping := rerpc.NewHandler( - "internal.ping.v1test.PingService.Ping", // fully-qualified protobuf method - "internal.ping.v1test.PingService", // fully-qualified protobuf service - "internal.ping.v1test", // fully-qualified protobuf package - func() proto.Message { return &PingRequest{} }, // request msg constructor - rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { - typed, ok := req.(*PingRequest) - if !ok { - return nil, rerpc.Errorf( - rerpc.CodeInternal, - "error in generated code: expected req to be a *PingRequest, got a %T", - req, - ) + "internal.ping.v1test", // protobuf package + "PingService", // protobuf service + "Ping", // protobuf method + rerpc.HandlerStreamFunc(func(ctx context.Context, stream rerpc.Stream) { + defer stream.CloseReceive() + var req PingRequest + if err := stream.Receive(&req); err != nil { + _ = stream.CloseSend(err) + return + } + res, err := pingFunc(ctx, &req) + if err != nil { + _ = stream.CloseSend(err) + return } - return svc.Ping(ctx, typed) + _ = stream.CloseSend(stream.Send(res)) }), opts..., ) - mux.Handle("/internal.ping.v1test.PingService/Ping", ping) - + mux.Handle(ping.Path(), ping) + + failFunc := rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { + typed, ok := req.(*FailRequest) + if !ok { + return nil, rerpc.Errorf( + rerpc.CodeInternal, + "can't call internal.ping.v1test.PingService.Fail with a %v", + req.ProtoReflect().Descriptor().FullName(), + ) + } + return svc.Fail(ctx, typed) + }) + if ic != nil { + failFunc = ic.Wrap(failFunc) + } fail := rerpc.NewHandler( - "internal.ping.v1test.PingService.Fail", // fully-qualified protobuf method - "internal.ping.v1test.PingService", // fully-qualified protobuf service - "internal.ping.v1test", // fully-qualified protobuf package - func() proto.Message { return &FailRequest{} }, // request msg constructor - rerpc.Func(func(ctx context.Context, req proto.Message) (proto.Message, error) { - typed, ok := req.(*FailRequest) - if !ok { - return nil, rerpc.Errorf( - rerpc.CodeInternal, - "error in generated code: expected req to be a *FailRequest, got a %T", - req, - ) + "internal.ping.v1test", // protobuf package + "PingService", // protobuf service + "Fail", // protobuf method + rerpc.HandlerStreamFunc(func(ctx context.Context, stream rerpc.Stream) { + defer stream.CloseReceive() + var req FailRequest + if err := stream.Receive(&req); err != nil { + _ = stream.CloseSend(err) + return + } + res, err := failFunc(ctx, &req) + if err != nil { + _ = stream.CloseSend(err) + return } - return svc.Fail(ctx, typed) + _ = stream.CloseSend(stream.Send(res)) }), opts..., ) - mux.Handle("/internal.ping.v1test.PingService/Fail", fail) + mux.Handle(fail.Path(), fail) // Respond to unknown protobuf methods with gRPC and Twirp's 404 equivalents. mux.Handle("/", rerpc.NewBadRouteHandler(opts...)) - return "/internal.ping.v1test.PingService/", mux + return fail.ServicePath(), mux } var _ PingServiceReRPC = (*UnimplementedPingServiceReRPC)(nil) // verify interface implementation diff --git a/message.go b/message.go index 3c0e8fbf..4739da45 100644 --- a/message.go +++ b/message.go @@ -3,12 +3,10 @@ package rerpc import ( "bytes" "compress/gzip" - "context" "encoding/binary" "errors" "fmt" "io" - "io/ioutil" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -21,123 +19,148 @@ var ( sizeZeroPrefix = make([]byte, 4) ) -func marshalJSON(ctx context.Context, w io.Writer, msg proto.Message, hooks *Hooks) { +type marshaler struct { + w io.Writer + ctype string + gzipGRPC bool +} + +func (m *marshaler) Marshal(msg proto.Message) error { + switch m.ctype { + case TypeJSON: + return m.marshalTwirpJSON(msg) + case TypeProtoTwirp: + return m.marshalTwirpProto(msg) + case TypeDefaultGRPC, TypeProtoGRPC: + return m.marshalGRPC(msg) + default: + return fmt.Errorf("unsupported Content-Type %q", m.ctype) + } +} + +func (m *marshaler) marshalTwirpJSON(msg proto.Message) error { bs, err := jsonpbMarshaler.Marshal(msg) if err != nil { - hooks.onMarshalError(ctx, fmt.Errorf("couldn't marshal protobuf message: %w", err)) - return - } - if _, err := w.Write(bs); err != nil { - hooks.onNetworkError(ctx, fmt.Errorf("couldn't write JSON: %w", err)) - return + return err } + _, err = m.w.Write(bs) + return err } -func marshalTwirpProto(ctx context.Context, w io.Writer, msg proto.Message, hooks *Hooks) { +func (m *marshaler) marshalTwirpProto(msg proto.Message) error { bs, err := proto.Marshal(msg) if err != nil { - hooks.onMarshalError(ctx, fmt.Errorf("couldn't marshal protobuf message: %w", err)) - return - } - if _, err := w.Write(bs); err != nil { - hooks.onNetworkError(ctx, fmt.Errorf("couldn't write Twirp protobuf: %w", err)) - return + return err } + _, err = m.w.Write(bs) + return err } -func unmarshalJSON(r io.Reader, msg proto.Message) error { - bs, err := io.ReadAll(r) +func (m *marshaler) marshalGRPC(msg proto.Message) error { + raw, err := proto.Marshal(msg) if err != nil { - return fmt.Errorf("can't read data: %w", err) + return fmt.Errorf("couldn't marshal protobuf message: %w", err) } - if len(bs) == 0 { - // zero value request + if !m.gzipGRPC { + if err := m.writeGRPCPrefix(false, len(raw)); err != nil { + return err // already enriched + } + if _, err := m.w.Write(raw); err != nil { + return fmt.Errorf("couldn't write message of length-prefixed message: %w", err) + } return nil } - if err := jsonpbUnmarshaler.Unmarshal(bs, msg); err != nil { - return fmt.Errorf("can't unmarshal JSON data into type %T: %w", msg, err) + data := getBuffer() + defer putBuffer(data) + gw := getGzipWriter(data) + defer putGzipWriter(gw) + + if _, err = gw.Write(raw); err != nil { // returns uncompressed size, which isn't useful + return fmt.Errorf("couldn't gzip data: %w", err) + } + if err := gw.Close(); err != nil { + return fmt.Errorf("couldn't close gzip writer: %w", err) + } + if err := m.writeGRPCPrefix(true, data.Len()); err != nil { + return err // already enriched + } + if _, err := io.Copy(m.w, data); err != nil { + return fmt.Errorf("couldn't write message of length-prefixed message: %w", err) } return nil } -func unmarshalTwirpProto(r io.Reader, msg proto.Message) error { - bs, err := io.ReadAll(r) - if err != nil { - return fmt.Errorf("can't read data: %w", err) - } - if len(bs) == 0 { - // zero value - return nil +func (m *marshaler) writeGRPCPrefix(compressed bool, size int) error { + prefixes := [5]byte{} + if compressed { + prefixes[0] = 1 } - if err := proto.Unmarshal(bs, msg); err != nil { - return fmt.Errorf("can't unmarshal protobuf data into type %T: %w", msg, err) + binary.BigEndian.PutUint32(prefixes[1:5], uint32(size)) + if _, err := m.w.Write(prefixes[:]); err != nil { + return fmt.Errorf("couldn't write prefix of length-prefixed message: %w", err) } return nil } -func marshalLPM(ctx context.Context, w io.Writer, msg proto.Message, compression string, maxBytes int, hooks *Hooks) error { - raw, err := proto.Marshal(msg) - if err != nil { - err = fmt.Errorf("couldn't marshal protobuf message: %w", err) - hooks.onMarshalError(ctx, err) - return err - } - data := &bytes.Buffer{} - var dataW io.Writer = data - switch compression { - case CompressionIdentity: - case CompressionGzip: - dataW = gzip.NewWriter(data) +type unmarshaler struct { + r io.Reader + ctype string + max int64 +} + +func (u *unmarshaler) Unmarshal(msg proto.Message) error { + switch u.ctype { + case TypeJSON: + return u.unmarshalTwirpJSON(msg) + case TypeProtoTwirp: + return u.unmarshalTwirpProto(msg) + case TypeDefaultGRPC, TypeProtoGRPC: + return u.unmarshalGRPC(msg) default: - err := fmt.Errorf("unsupported length-prefixed message compression %q", compression) - hooks.onInternalError(ctx, err) - return err - } - _, err = dataW.Write(raw) // returns uncompressed size, which isn't useful - if err != nil { - err = fmt.Errorf("couldn't compress with %q: %w", compression, err) - hooks.onInternalError(ctx, err) - return err - } - if c, ok := dataW.(io.Closer); ok { - if err := c.Close(); err != nil { - err = fmt.Errorf("couldn't close writer with compression %q: %w", compression, err) - hooks.onInternalError(ctx, err) - return err - } + return fmt.Errorf("unsupported Content-Type %q", u.ctype) } +} - size := data.Len() - if maxBytes > 0 && size > maxBytes { - return fmt.Errorf("message too large: got %d bytes, max is %d", size, maxBytes) - } - prefixes := [5]byte{} - if compression == CompressionIdentity { - prefixes[0] = 0 - } else { - prefixes[0] = 1 - } - binary.BigEndian.PutUint32(prefixes[1:5], uint32(size)) +func (u *unmarshaler) unmarshalTwirpJSON(msg proto.Message) error { + return u.unmarshalTwirp(msg, "JSON", jsonpbUnmarshaler.Unmarshal) +} - if _, err := w.Write(prefixes[:]); err != nil { - err = fmt.Errorf("couldn't write prefix of length-prefixed message: %w", err) - hooks.onNetworkError(ctx, err) - return err +func (u *unmarshaler) unmarshalTwirpProto(msg proto.Message) error { + return u.unmarshalTwirp(msg, "protobuf", proto.Unmarshal) +} + +func (u *unmarshaler) unmarshalTwirp(msg proto.Message, variant string, do func([]byte, proto.Message) error) error { + buf := getBuffer() + defer putBuffer(buf) + r := u.r + if u.max > 0 { + r = &io.LimitedReader{ + R: r, + N: int64(u.max), + } } - if _, err := io.Copy(w, data); err != nil { - err = fmt.Errorf("couldn't write data portion of length-prefixed message: %w", err) - hooks.onNetworkError(ctx, err) - return err + if n, err := buf.ReadFrom(u.r); err != nil { + return wrap(CodeUnknown, err) + } else if n == 0 { + return nil // zero value + } + if err := do(buf.Bytes(), msg); err != nil { + if lr, ok := r.(*io.LimitedReader); ok && lr.N <= 0 { + // likely more informative than unmarshaling error + return errorf(CodeUnknown, "request too large: max bytes set to %v", u.max) + } + fqn := msg.ProtoReflect().Descriptor().FullName() + return newMalformedError(fmt.Sprintf("can't unmarshal %s into %v: %v", variant, fqn, err)) } return nil } -func unmarshalLPM(r io.Reader, msg proto.Message, compression string, maxBytes int) error { +func (u *unmarshaler) unmarshalGRPC(msg proto.Message) error { // Each length-prefixed message starts with 5 bytes of metadata: a one-byte // unsigned integer indicating whether the payload is compressed, and a // four-byte unsigned integer indicating the message length. prefixes := make([]byte, 5) - n, err := r.Read(prefixes) + n, err := u.r.Read(prefixes) if err != nil && errors.Is(err, io.EOF) && n == 5 && bytes.Equal(prefixes[1:5], sizeZeroPrefix) { // Successfully read prefix, expect no additional data, and got an EOF, so // there's nothing left to do - the zero value of the msg is correct. @@ -148,34 +171,32 @@ func unmarshalLPM(r io.Reader, msg proto.Message, compression string, maxBytes i return fmt.Errorf("gRPC protocol error: missing length-prefixed message metadata: %w", err) } + // TODO: grpc-web uses the MSB of this byte to indicate that the LPM contains + // trailers. var compressed bool switch prefixes[0] { case 0: compressed = false - if compression != CompressionIdentity { - return fmt.Errorf("gRPC protocol error: protobuf is uncompressed but message compression is %q", compression) - } case 1: compressed = true - if compression == CompressionIdentity { - return errors.New("gRPC protocol error: protobuf is compressed but message should be uncompressed") - } default: return fmt.Errorf("gRPC protocol error: length-prefixed message has invalid compressed flag %v", prefixes[0]) } size := int(binary.BigEndian.Uint32(prefixes[1:5])) if size < 0 { - return fmt.Errorf("message size %d overflows uint32", size) - } - if maxBytes > 0 && size > maxBytes { - return fmt.Errorf("message too large: got %d bytes, max is %d", size, maxBytes) + return fmt.Errorf("message size %d overflowed uint32", size) + } else if u.max > 0 && int64(size) > u.max { + return fmt.Errorf("message size %d is larger than configured max %d", size, u.max) } + buf := getBuffer() + defer putBuffer(buf) - raw := make([]byte, size) + buf.Grow(size) + raw := buf.Bytes()[0:size] if size > 0 { - n, err = r.Read(raw) - if err != nil && err != io.EOF { + n, err = u.r.Read(raw) + if err != nil && !errors.Is(err, io.EOF) { return fmt.Errorf("error reading length-prefixed message data: %w", err) } if n < size { @@ -183,21 +204,23 @@ func unmarshalLPM(r io.Reader, msg proto.Message, compression string, maxBytes i } } - if compressed && compression == CompressionGzip { + if compressed { gr, err := gzip.NewReader(bytes.NewReader(raw)) if err != nil { return fmt.Errorf("can't decompress gzipped data: %w", err) } defer gr.Close() - decompressed, err := ioutil.ReadAll(gr) - if err != nil { + decompressed := getBuffer() + defer putBuffer(decompressed) + if _, err := decompressed.ReadFrom(gr); err != nil { return fmt.Errorf("can't decompress gzipped data: %w", err) } - raw = decompressed + raw = decompressed.Bytes() } if err := proto.Unmarshal(raw, msg); err != nil { - return fmt.Errorf("can't unmarshal data into type %T: %w", msg, err) + fqn := msg.ProtoReflect().Descriptor().FullName() + return fmt.Errorf("can't unmarshal protobuf into %v: %w", fqn, err) } return nil diff --git a/metadata.go b/metadata.go index 4342a4b2..e771eae8 100644 --- a/metadata.go +++ b/metadata.go @@ -17,14 +17,15 @@ const ( // Note that the Method, Service, and Package are fully-qualified protobuf // names, not Go import paths or identifiers. type Specification struct { - Method string // full protobuf name, e.g. "acme.foo.v1.FooService.Bar" - Service string // full protobuf name, e.g. "acme.foo.v1.FooService" - Package string // full protobuf name, e.g. "acme.foo.v1" + Package string // protobuf name, e.g. "acme.foo.v1" + Service string // protobuf name, e.g. "FooService" + Method string // protobuf name, e.g. "Bar" Path string ContentType string RequestCompression string ResponseCompression string + ReadMaxBytes int64 } // CallMetadata provides a Specification and access to request and response diff --git a/option.go b/option.go index 24028217..ebbaeae5 100644 --- a/option.go +++ b/option.go @@ -8,11 +8,9 @@ type Option interface { } type readMaxBytes struct { - Max int + Max int64 } -var _ Option = (*readMaxBytes)(nil) - // ReadMaxBytes limits the performance impact of pathologically large messages // sent by the other party. For handlers, ReadMaxBytes sets the maximum // allowable request size. For clients, ReadMaxBytes sets the maximum allowable @@ -20,7 +18,7 @@ var _ Option = (*readMaxBytes)(nil) // // Setting ReadMaxBytes to zero allows any request size. Both clients and // handlers default to allowing any request size. -func ReadMaxBytes(n int) Option { +func ReadMaxBytes(n int64) Option { return &readMaxBytes{n} } @@ -36,8 +34,6 @@ type gzipOption struct { Enable bool } -var _ Option = (*gzipOption)(nil) - // Gzip configures client and server compression strategies. // // For handlers, enabling gzip sends compressed responses to clients that @@ -67,8 +63,6 @@ type interceptOption struct { interceptor Interceptor } -var _ Option = (*interceptOption)(nil) - // Intercept configures a client or handler to use the supplied Interceptor. // Note that this Option replaces any previously-configured Interceptor - to // compose Interceptors, use a Chain. diff --git a/reflection.go b/reflection.go index 55bf6440..7f18ff7b 100644 --- a/reflection.go +++ b/reflection.go @@ -24,6 +24,8 @@ type Registrar struct { services map[string]struct{} } +var _ HandlerOption = (*Registrar)(nil) + // NewRegistrar constructs an empty Registrar. func NewRegistrar() *Registrar { return &Registrar{services: make(map[string]struct{})} @@ -53,15 +55,16 @@ func (r *Registrar) IsRegistered(service string) bool { return ok } -// Registers a fully-qualified protobuf service name. Safe to call +// Registers a protobuf package and service combination. Safe to call // concurrently. -func (r *Registrar) register(service string) { - if service == "" { +func (r *Registrar) register(pkg, service string) { + if pkg == "" || service == "" { // Typically BadRouteHandler. return } + fqn := pkg + "." + service r.mu.Lock() - r.services[service] = struct{}{} + r.services[fqn] = struct{}{} r.mu.Unlock() } @@ -83,67 +86,67 @@ func (r *Registrar) applyToHandler(cfg *handlerCfg) { // https://github.com/grpc/grpc-go/blob/master/Documentation/server-reflection-tutorial.md // https://github.com/grpc/grpc/blob/master/doc/server-reflection.md // https://github.com/fullstorydev/grpcurl -func NewReflectionHandler(reg *Registrar) (string, *http.ServeMux) { - const packageFQN = "grpc.reflection.v1alpha" - const serviceFQN = packageFQN + ".ServerReflection" +func NewReflectionHandler(reg *Registrar, opts ...HandlerOption) (string, *http.ServeMux) { + const pkg = "grpc.reflection.v1alpha" + const service = "ServerReflection" const method = "ServerReflectionInfo" - const methodFQN = serviceFQN + "." + method - const servicePath = "/" + serviceFQN + "/" - const methodPath = servicePath + method - reg.register(serviceFQN) + opts = append(opts, reg, ServeTwirp(false)) // no reflection in Twirp + svc := &reflectionServer{reg} + wrapped := HandlerStreamFunc(svc.Serve) + if i := ConfiguredHandlerInterceptor(opts...); i != nil { + fmt.Println("TODO: apply interceptors") + } h := NewHandler( - methodFQN, serviceFQN, packageFQN, - func() proto.Message { return nil }, - nil, // no unary implementation - ServeTwirp(false), // no reflection in Twirp + pkg, service, method, + wrapped, + opts..., ) - raw := &rawReflectionHandler{reg} - h.stream = raw.rawGRPC - mux := http.NewServeMux() - mux.Handle(methodPath, h) + mux.Handle(h.Path(), h) mux.Handle("/", NewBadRouteHandler()) - return servicePath, mux + return h.ServicePath(), mux } -type rawReflectionHandler struct { +type reflectionServer struct { reg *Registrar } -func (rh *rawReflectionHandler) rawGRPC(ctx context.Context, w http.ResponseWriter, r *http.Request, requestCompression, responseCompression string, hooks *Hooks) { - if r.ProtoMajor < 2 { - w.WriteHeader(http.StatusHTTPVersionNotSupported) - io.WriteString(w, "bidirectional streaming requires HTTP/2") - return - } +func (rs *reflectionServer) Serve(ctx context.Context, stream Stream) { + defer stream.CloseReceive() for { + if err := ctx.Err(); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + _ = stream.CloseSend(wrap(CodeDeadlineExceeded, err)) + } else if errors.Is(err, context.Canceled) { + _ = stream.CloseSend(wrap(CodeCanceled, err)) + } else { + _ = stream.CloseSend(wrap(CodeUnknown, err)) + } + return + } var req rpb.ServerReflectionRequest - if err := unmarshalLPM(r.Body, &req, requestCompression, 0); err != nil && errors.Is(err, io.EOF) { - writeErrorGRPC(ctx, w, nil, hooks) + if err := stream.Receive(&req); err != nil && errors.Is(err, io.EOF) { + _ = stream.CloseSend(nil) return } else if err != nil { - writeErrorGRPC(ctx, w, errorf(CodeUnknown, "can't unmarshal protobuf"), hooks) + _ = stream.CloseSend(err) return } - res, serr := rh.serve(&req) - if serr != nil { - writeErrorGRPC(ctx, w, serr, hooks) + res, err := rs.serve(&req) + if err != nil { + _ = stream.CloseSend(err) return } - if err := marshalLPM(ctx, w, res, responseCompression, 0, hooks); err != nil { - writeErrorGRPC(ctx, w, errorf(CodeUnknown, "can't marshal protobuf"), hooks) + if err := stream.Send(res); err != nil { + _ = stream.CloseSend(err) return } - - if f, ok := w.(http.Flusher); ok { - f.Flush() - } } } -func (rh *rawReflectionHandler) serve(req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, *Error) { +func (rs *reflectionServer) serve(req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, *Error) { // The grpc-go implementation of server reflection uses the APIs from // github.com/google/protobuf, which makes the logic fairly complex. The new // google.golang.org/protobuf/reflect/protoregistry exposes a higher-level @@ -219,7 +222,7 @@ func (rh *rawReflectionHandler) serve(req *rpb.ServerReflectionRequest) (*rpb.Se } } case *rpb.ServerReflectionRequest_ListServices: - services := rh.reg.Services() + services := rs.reg.Services() serviceResponses := make([]*rpb.ServiceResponse, len(services)) for i, n := range services { serviceResponses[i] = &rpb.ServiceResponse{ diff --git a/rerpc_test.go b/rerpc_test.go index ca390df9..0e2b7628 100644 --- a/rerpc_test.go +++ b/rerpc_test.go @@ -305,54 +305,41 @@ func TestServerProtoGRPC(t *testing.T) { assert.True(t, reg.IsRegistered(pingFQN), "ping service registered") assert.False(t, reg.IsRegistered(unknown), "unknown service registered") - healthPFQN := "grpc.health.v1" - healthFQN := healthPFQN + ".Health" - callCheck := func(req *healthpb.HealthCheckRequest, opts ...rerpc.CallOption) (*healthpb.HealthCheckResponse, error) { - client := rerpc.NewClient( - doer, - url+"/grpc.health.v1.Health/Check", - healthFQN+".Check", - healthFQN, - healthPFQN, - func() proto.Message { return &healthpb.HealthCheckResponse{} }, - ) - res, err := client.Call(context.Background(), req, opts...) - if err != nil { + callHealth := func(method string, req *healthpb.HealthCheckRequest, opts ...rerpc.CallOption) (*healthpb.HealthCheckResponse, error) { + client := rerpc.NewClient(doer, url, "grpc.health.v1", "Health", method) + stream := client.Call(context.Background(), opts...) + if err := stream.Send(req); err != nil { + _ = stream.CloseSend(err) + _ = stream.CloseReceive() return nil, err } - return res.(*healthpb.HealthCheckResponse), nil - } - callWatch := func(req *healthpb.HealthCheckRequest, opts ...rerpc.CallOption) (*healthpb.HealthCheckResponse, error) { - client := rerpc.NewClient( - doer, - url+"/grpc.health.v1.Health/Watch", - healthFQN+".Watch", - healthFQN, - healthPFQN, - func() proto.Message { return &healthpb.HealthCheckResponse{} }, - ) - res, err := client.Call(context.Background(), req, opts...) - if err != nil { + if err := stream.CloseSend(nil); err != nil { + _ = stream.CloseReceive() + return nil, err + } + var res healthpb.HealthCheckResponse + if err := stream.Receive(&res); err != nil { + _ = stream.CloseReceive() return nil, err } - return res.(*healthpb.HealthCheckResponse), nil + return &res, stream.CloseReceive() } t.Run("process", func(t *testing.T) { req := &healthpb.HealthCheckRequest{} - res, err := callCheck(req, opts...) + res, err := callHealth("Check", req, opts...) assert.Nil(t, err, "rpc error") assert.Equal(t, rerpc.HealthStatus(res.Status), rerpc.HealthServing, "status") }) t.Run("known", func(t *testing.T) { req := &healthpb.HealthCheckRequest{Service: pingFQN} - res, err := callCheck(req, opts...) + res, err := callHealth("Check", req, opts...) assert.Nil(t, err, "rpc error") assert.Equal(t, rerpc.HealthStatus(res.Status), rerpc.HealthServing, "status") }) t.Run("unknown", func(t *testing.T) { req := &healthpb.HealthCheckRequest{Service: unknown} - _, err := callCheck(req, opts...) + _, err := callHealth("Check", req, opts...) assert.NotNil(t, err, "rpc error") rerr, ok := rerpc.AsError(err) assert.True(t, ok, "convert to rerpc error") @@ -360,7 +347,7 @@ func TestServerProtoGRPC(t *testing.T) { }) t.Run("watch", func(t *testing.T) { req := &healthpb.HealthCheckRequest{Service: pingFQN} - _, err := callWatch(req, opts...) + _, err := callHealth("Watch", req, opts...) assert.NotNil(t, err, "rpc error") rerr, ok := rerpc.AsError(err) assert.True(t, ok, "convert to rerpc error") @@ -369,9 +356,6 @@ func TestServerProtoGRPC(t *testing.T) { }) } testReflection := func(t *testing.T, url string, doer rerpc.Doer, opts ...rerpc.CallOption) { - const reflectPFQN = "grpc.reflection.v1alpha" - const reflectSFQN = reflectPFQN + ".ServerReflection" - const reflectFQN = reflectSFQN + ".ServerReflectionInfo" pingRequestFQN := string((&pingpb.PingRequest{}).ProtoReflect().Descriptor().FullName()) assert.Equal(t, reg.Services(), []string{ "grpc.health.v1.Health", @@ -380,19 +364,23 @@ func TestServerProtoGRPC(t *testing.T) { }, "services registered in memory") callReflect := func(req *reflectionpb.ServerReflectionRequest, opts ...rerpc.CallOption) (*reflectionpb.ServerReflectionResponse, error) { - client := rerpc.NewClient( - doer, - url+"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", - reflectFQN, - reflectSFQN, - reflectPFQN, - func() proto.Message { return &reflectionpb.ServerReflectionResponse{} }, - ) - res, err := client.Call(context.Background(), req, opts...) - if err != nil { + client := rerpc.NewClient(doer, url, "grpc.reflection.v1alpha", "ServerReflection", "ServerReflectionInfo") + stream := client.Call(context.Background(), opts...) + if err := stream.Send(req); err != nil { + _ = stream.CloseSend(err) + _ = stream.CloseReceive() + return nil, err + } + if err := stream.CloseSend(nil); err != nil { + _ = stream.CloseReceive() + return nil, err + } + var res reflectionpb.ServerReflectionResponse + if err := stream.Receive(&res); err != nil { + _ = stream.CloseReceive() return nil, err } - return res.(*reflectionpb.ServerReflectionResponse), nil + return &res, stream.CloseReceive() } t.Run("list_services", func(t *testing.T) { req := &reflectionpb.ServerReflectionRequest{ diff --git a/stream.go b/stream.go new file mode 100644 index 00000000..edf324f7 --- /dev/null +++ b/stream.go @@ -0,0 +1,419 @@ +package rerpc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "google.golang.org/protobuf/proto" + + statuspb "github.com/rerpc/rerpc/internal/status/v1" + "github.com/rerpc/rerpc/internal/twirp" +) + +type HandlerStreamFunc func(context.Context, Stream) +type CallStreamFunc func(context.Context) Stream + +// Stream is a bidirectional stream of protobuf messages. Streams aren't +// guaranteed to be safe for concurrent use. +type Stream interface { + Send(proto.Message) error + CloseSend(error) error + Receive(proto.Message) error + CloseReceive() error +} + +type exchange struct { + msg proto.Message + errs chan error +} + +type clientStream struct { + ctx context.Context + doer Doer + url string + maxReadBytes int64 + + // send goroutine + sendCh chan *exchange + sendClosed chan struct{} + writer *io.PipeWriter + marshaler marshaler + + // receive goroutine + reader *io.PipeReader + response *http.Response + responseErr error + responseReady chan struct{} + unmarshaler unmarshaler +} + +var _ Stream = (*clientStream)(nil) + +func newClientStream( + ctx context.Context, + doer Doer, + url string, + maxReadBytes int64, + gzipRequest bool, +) *clientStream { + pr, pw := io.Pipe() + stream := clientStream{ + ctx: ctx, + doer: doer, + url: url, + maxReadBytes: maxReadBytes, + sendCh: make(chan *exchange), + sendClosed: make(chan struct{}), + writer: pw, + marshaler: marshaler{w: pw, ctype: TypeDefaultGRPC, gzipGRPC: gzipRequest}, + reader: pr, + responseReady: make(chan struct{}), + } + go stream.startSendLoop() + requestPrepared := make(chan struct{}) + go stream.makeRequest(requestPrepared) + <-requestPrepared + return &stream +} + +func (cs *clientStream) Send(msg proto.Message) error { + errs := make(chan error) + cs.sendCh <- &exchange{msg, errs} + return <-errs +} + +func (cs *clientStream) CloseSend(_ error) error { + close(cs.sendCh) + <-cs.sendClosed + if err := cs.writer.Close(); err != nil { + return wrap(CodeUnknown, err) + } + return nil +} + +func (cs *clientStream) startSendLoop() { + defer close(cs.sendClosed) + for xc := range cs.sendCh { + if err := cs.marshaler.Marshal(xc.msg); err != nil { + // If the server has already sent us an error (or the request has failed + // in some other way), we'll get that error here. + if _, ok := AsError(err); !ok { + err = wrap(CodeUnknown, err) + } + xc.errs <- err + } else { + xc.errs <- nil + } + } +} + +func (cs *clientStream) Receive(msg proto.Message) error { + <-cs.responseReady + if cs.responseErr != nil { + return cs.responseErr + } + err := cs.unmarshaler.Unmarshal(msg) + if err != nil { + // If we can't read this LPM, see if the server sent an explicit error in + // trailers. First, we need to read the body to EOF. + io.Copy(io.Discard, cs.response.Body) + if serverErr := extractError(cs.response.Trailer); serverErr != nil { + cs.setResponseError(serverErr) + return serverErr + } + cs.setResponseError(wrap(CodeUnknown, err)) + return cs.responseErr + } + return nil +} + +func (cs *clientStream) CloseReceive() error { + <-cs.responseReady + if cs.response == nil { + return nil + } + io.Copy(io.Discard, cs.response.Body) + if err := cs.response.Body.Close(); err != nil { + return wrap(CodeUnknown, err) + } + return nil +} + +func (cs *clientStream) makeRequest(prepared chan struct{}) { + defer close(cs.responseReady) + + md, ok := CallMeta(cs.ctx) + if !ok { + cs.setResponseError(errorf(CodeInternal, "no call metadata available on context")) + close(prepared) + return + } + + if deadline, ok := cs.ctx.Deadline(); ok { + untilDeadline := time.Until(deadline) + if untilDeadline <= 0 { + cs.setResponseError(errorf(CodeDeadlineExceeded, "no time to make RPC: timeout is %v", untilDeadline)) + close(prepared) + return + } + if enc, err := encodeTimeout(untilDeadline); err == nil { + // Tests verify that the error in encodeTimeout is unreachable, so we + // should be safe without observability for the error case. + md.req.raw.Set("Grpc-Timeout", enc) + } + } + + req, err := http.NewRequestWithContext(cs.ctx, http.MethodPost, cs.url, cs.reader) + if err != nil { + cs.setResponseError(wrap(CodeUnknown, err)) + close(prepared) + return + } + req.Header = md.req.raw + + // Before we send off a request, check if we're already out of time. + if err := cs.ctx.Err(); err != nil { + code := CodeUnknown + if errors.Is(err, context.Canceled) { + code = CodeCanceled + } + if errors.Is(err, context.DeadlineExceeded) { + code = CodeDeadlineExceeded + } + cs.setResponseError(wrap(code, err)) + close(prepared) + return + } + + close(prepared) + res, err := cs.doer.Do(req) + if err != nil { + // Error message comes from our networking stack, so it's safe to expose. + code := CodeUnknown + if errors.Is(err, context.Canceled) { + code = CodeCanceled + } + if errors.Is(err, context.DeadlineExceeded) { + code = CodeDeadlineExceeded + } + cs.setResponseError(wrap(code, err)) + return + } + *md.res = NewImmutableHeader(res.Header) + + if res.StatusCode != http.StatusOK { + code := CodeUnknown + if c, ok := httpToGRPC[res.StatusCode]; ok { + code = c + } + cs.setResponseError(errorf(code, "HTTP status %v", res.StatusCode)) + return + } + compression := res.Header.Get("Grpc-Encoding") + if compression == "" { + compression = CompressionIdentity + } + switch compression { + case CompressionIdentity, CompressionGzip: + default: + // Per https://github.com/grpc/grpc/blob/master/doc/compression.md, we + // should return CodeInternal and specify acceptable compression(s) (in + // addition to setting the Grpc-Accept-Encoding header). + cs.setResponseError(errorf( + CodeInternal, + "unknown compression %q: accepted grpc-encoding values are %v", + compression, + acceptEncodingValue, + )) + return + } + // When there's no body, errors sent from the first-party gRPC servers will + // be in the headers. + if err := extractError(res.Header); err != nil { + cs.setResponseError(err) + return + } + // Success! + cs.response = res + cs.unmarshaler = unmarshaler{r: res.Body, ctype: TypeDefaultGRPC, max: cs.maxReadBytes} +} + +func (cs *clientStream) setResponseError(err error) { + cs.responseErr = err + // The write end of the pipe will now return this error too. + cs.reader.CloseWithError(err) +} + +type serverStream struct { + unmarshaler unmarshaler + marshaler marshaler + writer http.ResponseWriter + reader io.ReadCloser + ctype string +} + +var _ Stream = (*serverStream)(nil) + +func newServerStream( + w http.ResponseWriter, + r io.ReadCloser, + ctype string, + maxReadBytes int64, + gzipResponse bool, +) *serverStream { + return &serverStream{ + unmarshaler: unmarshaler{r: r, ctype: ctype, max: maxReadBytes}, + marshaler: marshaler{w: w, ctype: ctype, gzipGRPC: gzipResponse}, + writer: w, + reader: r, + ctype: ctype, + } +} + +func (ss *serverStream) Receive(msg proto.Message) error { + if err := ss.unmarshaler.Unmarshal(msg); err != nil { + return wrap(CodeInvalidArgument, err) + } + return nil +} + +func (ss *serverStream) CloseReceive() error { + io.Copy(io.Discard, ss.reader) + if err := ss.reader.Close(); err != nil { + return wrap(CodeUnknown, err) + } + return nil +} + +func (ss *serverStream) Send(msg proto.Message) error { + defer ss.flush() + if err := ss.marshaler.Marshal(msg); err != nil { + // we shouldn't ever fail to marshal a proto message + return wrap(CodeInternal, err) + } + return nil +} + +func (ss *serverStream) CloseSend(err error) error { + defer ss.flush() + switch ss.ctype { + case TypeJSON, TypeProtoTwirp: + return ss.sendErrorTwirp(err) + case TypeDefaultGRPC, TypeProtoGRPC: + return ss.sendErrorGRPC(err) + default: + return errorf(CodeInvalidArgument, "unsupported Content-Type %q", ss.ctype) + } +} + +func (ss *serverStream) sendErrorGRPC(err error) error { + if CodeOf(err) == CodeOK { // safe for nil errors + ss.writer.Header().Set("Grpc-Status", strconv.Itoa(int(CodeOK))) + ss.writer.Header().Set("Grpc-Message", "") + ss.writer.Header().Set("Grpc-Status-Details-Bin", "") + return nil + } + s := statusFromError(err) + code := strconv.Itoa(int(s.Code)) + if bin, err := proto.Marshal(s); err != nil { + ss.writer.Header().Set("Grpc-Status", strconv.Itoa(int(CodeInternal))) + ss.writer.Header().Set("Grpc-Message", percentEncode("error marshaling protobuf status with code "+code)) + return errorf(CodeInternal, "couldn't marshal protobuf status: %w", err) + } else { + ss.writer.Header().Set("Grpc-Status", code) + ss.writer.Header().Set("Grpc-Message", percentEncode(s.Message)) + ss.writer.Header().Set("Grpc-Status-Details-Bin", encodeBinaryHeader(bin)) + return nil + } +} + +func (ss *serverStream) sendErrorTwirp(err error) error { + if err == nil { + return nil + } + gs := statusFromError(err) + s := &twirp.Status{ + Code: Code(gs.Code).twirp(), + Message: gs.Message, + } + if te, ok := asTwirpError(err); ok { + s.Code = te.TwirpCode() + } + // Even if the caller sends TypeProtoTwirp, we respond with TypeJSON on + // errors. + ss.writer.Header().Set("Content-Type", TypeJSON) + bs, merr := json.Marshal(s) + if merr != nil { + ss.writer.WriteHeader(http.StatusInternalServerError) + // codes don't need to be escaped in JSON, so this is okay + const tmpl = `{"code": "%s", "msg": "error marshaling error with code %s"}` + // Ignore this error. We're well past the point of no return here. + _, _ = fmt.Fprintf(ss.writer, tmpl, CodeInternal.twirp(), s.Code) + return errorf(CodeInternal, "couldn't marshal Twirp status to JSON: %w", merr) + } + ss.writer.WriteHeader(CodeOf(err).http()) + if _, err = ss.writer.Write(bs); err != nil { + return wrap(CodeUnknown, err) + } + return nil +} + +func (ss *serverStream) flush() { + if f, ok := ss.writer.(http.Flusher); ok { + f.Flush() + } +} + +func statusFromError(err error) *statuspb.Status { + s := &statuspb.Status{ + Code: int32(CodeUnknown), + Message: err.Error(), + } + if re, ok := AsError(err); ok { + s.Code = int32(re.Code()) + s.Details = re.Details() + if e := re.Unwrap(); e != nil { + s.Message = e.Error() // don't repeat code + } + } + return s +} + +func extractError(h http.Header) *Error { + codeHeader := h.Get("Grpc-Status") + codeIsSuccess := (codeHeader == "" || codeHeader == "0") + if codeIsSuccess { + return nil + } + + code, err := strconv.ParseUint(codeHeader, 10 /* base */, 32 /* bitsize */) + if err != nil { + return errorf(CodeUnknown, "gRPC protocol error: got invalid error code %q", codeHeader) + } + message := percentDecode(h.Get("Grpc-Message")) + ret := wrap(Code(code), errors.New(message)) + + detailsBinaryEncoded := h.Get("Grpc-Status-Details-Bin") + if len(detailsBinaryEncoded) > 0 { + detailsBinary, err := decodeBinaryHeader(detailsBinaryEncoded) + if err != nil { + return errorf(CodeUnknown, "server returned invalid grpc-error-details-bin trailer: %w", err) + } + var status statuspb.Status + if err := proto.Unmarshal(detailsBinary, &status); err != nil { + return errorf(CodeUnknown, "server returned invalid protobuf for error details: %w", err) + } + ret.details = status.Details + // Prefer the protobuf-encoded data to the headers (grpc-go does this too). + ret.code = Code(status.Code) + ret.err = errors.New(status.Message) + } + + return ret +}