From 1afdb0d6e6b0926ea303c73096a336158e061929 Mon Sep 17 00:00:00 2001 From: Lukas Jenicek Date: Mon, 30 Sep 2024 12:30:05 +0200 Subject: [PATCH] call OnRequest before handler is called and handle err --- _examples/golang-basics/example.gen.go | 41 +++++++++++++++++------ _examples/golang-imports/api.gen.go | 46 ++++++++++++++++++++------ helpers.go.tmpl | 23 ++++++++++--- server.go.tmpl | 17 +++++++--- types.go.tmpl | 2 +- 5 files changed, 97 insertions(+), 32 deletions(-) diff --git a/_examples/golang-basics/example.gen.go b/_examples/golang-basics/example.gen.go index af47cad..2efa6ea 100644 --- a/_examples/golang-basics/example.gen.go +++ b/_examples/golang-basics/example.gen.go @@ -1,6 +1,6 @@ // example v0.0.1 05b7a5c86b98738f4fe6ce9bb1fccd4af064847a // -- -// Code generated by webrpc-gen@v0.19.3-11-g71ce490 with ../../../gen-golang generator. DO NOT EDIT. +// Code generated by webrpc-gen@v0.19.3-14-g44bb43f with ../../../gen-golang generator. DO NOT EDIT. // // webrpc-gen -schema=example.ridl -target=../../../gen-golang -pkg=main -server -client -legacyErrors -fixEmptyArrays -out=./example.gen.go package main @@ -247,11 +247,7 @@ func (s *exampleServiceServer) ServeHTTP(w http.ResponseWriter, r *http.Request) ctx = context.WithValue(ctx, HTTPResponseWriterCtxKey, w) ctx = context.WithValue(ctx, HTTPRequestCtxKey, r) ctx = context.WithValue(ctx, ServiceNameCtxKey, "ExampleService") - ctx = context.WithValue(ctx, MethodAnnotationsCtxKey, methodAnnotations[r.URL.Path]) - - if s.OnRequest != nil { - s.OnRequest(w, r) - } + ctx = context.WithValue(ctx, methodAnnotationsCtxKey, methodAnnotations[r.URL.Path]) var handler func(ctx context.Context, w http.ResponseWriter, r *http.Request) switch r.URL.Path { @@ -288,6 +284,17 @@ func (s *exampleServiceServer) ServeHTTP(w http.ResponseWriter, r *http.Request) switch contentType { case "application/json": + if s.OnRequest != nil { + if err := s.OnRequest(w, r); err != nil { + rpcErr, ok := err.(WebRPCError) + if !ok { + rpcErr = ErrWebrpcEndpoint.WithCause(err) + } + s.sendErrorJSON(w, r, rpcErr) + return + } + } + handler(ctx, w, r) default: err := ErrWebrpcBadRequest.WithCause(fmt.Errorf("unsupported Content-Type %q (only application/json is allowed)", r.Header.Get("Content-Type"))) @@ -766,6 +773,12 @@ func HTTPRequestHeaders(ctx context.Context) (http.Header, bool) { // Helpers // +type MethodCtx struct { + Name string + Service string + Annotations map[string]string +} + type contextKey struct { name string } @@ -784,7 +797,7 @@ var ( MethodNameCtxKey = &contextKey{"MethodName"} - MethodAnnotationsCtxKey = &contextKey{"MethodAnnotations"} + methodAnnotationsCtxKey = &contextKey{"MethodAnnotations"} ) func ServiceNameFromContext(ctx context.Context) string { @@ -802,10 +815,18 @@ func RequestFromContext(ctx context.Context) *http.Request { return r } -func MethodAnnotationsFromContext(ctx context.Context) map[string]string { - annotations, _ := ctx.Value(MethodAnnotationsCtxKey).(map[string]string) - return annotations +func MethodFromContext(ctx context.Context) MethodCtx { + name, _ := ctx.Value(MethodNameCtxKey).(string) + service, _ := ctx.Value(ServiceNameCtxKey).(string) + annotations, _ := ctx.Value(methodAnnotationsCtxKey).(map[string]string) + + return MethodCtx{ + Name: name, + Service: service, + Annotations: annotations, + } } + func ResponseWriterFromContext(ctx context.Context) http.ResponseWriter { w, _ := ctx.Value(HTTPResponseWriterCtxKey).(http.ResponseWriter) return w diff --git a/_examples/golang-imports/api.gen.go b/_examples/golang-imports/api.gen.go index 213d8fe..c11695d 100644 --- a/_examples/golang-imports/api.gen.go +++ b/_examples/golang-imports/api.gen.go @@ -1,6 +1,6 @@ // example-api-service v1.0.0 cae4e128f4fb4c938bfe1ea312deeea3dfd6b6af // -- -// Code generated by webrpc-gen@v0.19.3-11-g71ce490 with ../../../gen-golang generator. DO NOT EDIT. +// Code generated by webrpc-gen@v0.19.3-14-g44bb43f with ../../../gen-golang generator. DO NOT EDIT. // // webrpc-gen -schema=./proto/api.ridl -target=../../../gen-golang -out=./api.gen.go -pkg=main -server -client -legacyErrors=true -fmt=false package main @@ -82,7 +82,9 @@ func (x *Location) Is(values ...Location) bool { } } return false -}var ( +} + +var ( methodAnnotations = map[string]map[string]string{ "/rpc/ExampleAPI/Ping": {}, "/rpc/ExampleAPI/Status": {}, @@ -159,11 +161,7 @@ func (s *exampleAPIServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, HTTPResponseWriterCtxKey, w) ctx = context.WithValue(ctx, HTTPRequestCtxKey, r) ctx = context.WithValue(ctx, ServiceNameCtxKey, "ExampleAPI") - ctx = context.WithValue(ctx, MethodAnnotationsCtxKey, methodAnnotations[r.URL.Path]) - - if s.OnRequest != nil { - s.OnRequest(w, r) - } + ctx = context.WithValue(ctx, methodAnnotationsCtxKey, methodAnnotations[r.URL.Path]) var handler func(ctx context.Context, w http.ResponseWriter, r *http.Request) switch r.URL.Path { @@ -194,6 +192,17 @@ func (s *exampleAPIServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch contentType { case "application/json": + if s.OnRequest != nil { + if err := s.OnRequest(w, r); err != nil { + rpcErr, ok := err.(WebRPCError) + if !ok { + rpcErr = ErrWebrpcEndpoint.WithCause(err) + } + s.sendErrorJSON(w, r, rpcErr) + return + } + } + handler(ctx, w, r) default: err := ErrWebrpcBadRequest.WithCause(fmt.Errorf("unsupported Content-Type %q (only application/json is allowed)", r.Header.Get("Content-Type"))) @@ -504,6 +513,12 @@ func HTTPRequestHeaders(ctx context.Context) (http.Header, bool) { // Helpers // +type MethodCtx struct { + Name string + Service string + Annotations map[string]string +} + type contextKey struct { name string } @@ -522,7 +537,7 @@ var ( MethodNameCtxKey = &contextKey{"MethodName"} - MethodAnnotationsCtxKey = &contextKey{"MethodAnnotations"} + methodAnnotationsCtxKey = &contextKey{"MethodAnnotations"} ) func ServiceNameFromContext(ctx context.Context) string { @@ -540,10 +555,19 @@ func RequestFromContext(ctx context.Context) *http.Request { return r } -func MethodAnnotationsFromContext(ctx context.Context) map[string]string { - annotations, _ := ctx.Value(MethodAnnotationsCtxKey).(map[string]string) - return annotations +func MethodFromContext(ctx context.Context) MethodCtx { + name, _ := ctx.Value(MethodNameCtxKey).(string) + service, _ := ctx.Value(ServiceNameCtxKey).(string) + annotations, _ := ctx.Value(methodAnnotationsCtxKey).(map[string]string) + + return MethodCtx{ + Name: name, + Service: service, + Annotations: annotations, + } } + + func ResponseWriterFromContext(ctx context.Context) http.ResponseWriter { w, _ := ctx.Value(HTTPResponseWriterCtxKey).(http.ResponseWriter) return w diff --git a/helpers.go.tmpl b/helpers.go.tmpl index 32036d8..3a7f2ad 100644 --- a/helpers.go.tmpl +++ b/helpers.go.tmpl @@ -5,6 +5,12 @@ // Helpers // +type MethodCtx struct { + Name string + Service string + Annotations map[string]string +} + type contextKey struct { name string } @@ -27,7 +33,7 @@ var ( MethodNameCtxKey = &contextKey{"MethodName"} - MethodAnnotationsCtxKey = &contextKey{"MethodAnnotations"} + methodAnnotationsCtxKey = &contextKey{"MethodAnnotations"} ) func ServiceNameFromContext(ctx context.Context) string { @@ -45,12 +51,19 @@ func RequestFromContext(ctx context.Context) *http.Request { return r } -func MethodAnnotationsFromContext(ctx context.Context) map[string]string { - annotations, _ := ctx.Value(MethodAnnotationsCtxKey).(map[string]string) - return annotations +func MethodFromContext(ctx context.Context) MethodCtx { + name, _ := ctx.Value(MethodNameCtxKey).(string) + service, _ := ctx.Value(ServiceNameCtxKey).(string) + annotations, _ := ctx.Value(methodAnnotationsCtxKey).(map[string]string) + + return MethodCtx{ + Name: name, + Service: service, + Annotations: annotations, + } } -{{- if $opts.server}} +{{ if $opts.server}} func ResponseWriterFromContext(ctx context.Context) http.ResponseWriter { w, _ := ctx.Value(HTTPResponseWriterCtxKey).(http.ResponseWriter) return w diff --git a/server.go.tmpl b/server.go.tmpl index 87039a1..fba3aa2 100644 --- a/server.go.tmpl +++ b/server.go.tmpl @@ -42,11 +42,7 @@ func (s *{{$serviceName}}) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, HTTPResponseWriterCtxKey, w) ctx = context.WithValue(ctx, HTTPRequestCtxKey, r) ctx = context.WithValue(ctx, ServiceNameCtxKey, "{{.Name}}") - ctx = context.WithValue(ctx, MethodAnnotationsCtxKey, methodAnnotations[r.URL.Path]) - - if s.OnRequest != nil { - s.OnRequest(w, r) - } + ctx = context.WithValue(ctx, methodAnnotationsCtxKey, methodAnnotations[r.URL.Path]) var handler func(ctx context.Context, w http.ResponseWriter, r *http.Request) switch r.URL.Path { @@ -75,6 +71,17 @@ func (s *{{$serviceName}}) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch contentType { case "application/json": + if s.OnRequest != nil { + if err := s.OnRequest(w, r); err != nil { + rpcErr, ok := err.(WebRPCError) + if !ok { + rpcErr = ErrWebrpcEndpoint.WithCause(err) + } + s.sendErrorJSON(w, r, rpcErr) + return + } + } + handler(ctx, w, r) default: err := ErrWebrpcBadRequest.WithCause(fmt.Errorf("unsupported Content-Type %q (only application/json is allowed)", r.Header.Get("Content-Type"))) diff --git a/types.go.tmpl b/types.go.tmpl index 5e0af53..1a1b2dd 100644 --- a/types.go.tmpl +++ b/types.go.tmpl @@ -29,7 +29,7 @@ {{- end }} -{{- range $_, $service := $services -}} +{{ range $_, $service := $services -}} var ( methodAnnotations = map[string]map[string]string{ {{- range $_, $method := $service.Methods }}