Skip to content

Commit

Permalink
fix(injectablegen): support embeded provide
Browse files Browse the repository at this point in the history
  • Loading branch information
morlay committed Oct 21, 2024
1 parent 360e3af commit f0776d2
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 65 deletions.
103 changes: 86 additions & 17 deletions devpkg/injectablegen/injectable.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ func (g *injectableGen) GenerateType(c gengo.Context, t *types.Named) error {
values, ok := tags["gengo:injectable:provider"]
if ok {
if len(values) > 0 {
if err := g.genAsProvider(c, t, values[0]); err != nil {
if err := g.genAsProvider(c, t, values[0], false); err != nil {
return err
}
} else {
if err := g.genAsProvider(c, t, ""); err != nil {
if err := g.genAsProvider(c, t, "", false); err != nil {
return err
}
}
Expand All @@ -72,9 +72,33 @@ func (g *injectableGen) GenerateType(c gengo.Context, t *types.Named) error {
return nil
}

func (g *injectableGen) genAsProvider(c gengo.Context, t *types.Named, impl string) error {
switch t.Underlying().(type) {
case *types.Alias:
func (g *injectableGen) GenerateAliasType(c gengo.Context, t *types.Alias) error {
tags, _ := c.Doc(t.Obj())

g.once.Do(func() {
g.init(c)
})

values, ok := tags["gengo:injectable:provider"]
if ok {
if len(values) > 0 {
if err := g.genAsProvider(c, t, values[0], true); err != nil {
return err
}
} else {
if err := g.genAsProvider(c, t, "", true); err != nil {
return err
}
}
}
return nil
}

func (g *injectableGen) genAsProvider(c gengo.Context, t interface {
Obj() *types.TypeName
Underlying() types.Type
}, impl string, forAlias bool) error {
switch x := t.Underlying().(type) {
case *types.Interface:
c.Render(gengo.Snippet{
gengo.T: `
Expand All @@ -96,18 +120,53 @@ func @Type'InjectContext(ctx @contextContext, tpe @Type) (@contextContext) {
"contextWithValue": gengo.ID("context.WithValue"),
})
case *types.Struct:
provideFields := func(sw gengo.SnippetWriter) {
if forAlias {
return
}

for i := 0; i < x.NumFields(); i++ {
f := x.Field(i)
structTag := reflect.StructTag(x.Tag(i))

injectTag, exists := structTag.Lookup("provide")
if exists && injectTag != "-" {
typ := f.Type()
for {
x, ok := typ.(*types.Pointer)
if !ok {
break
}
typ = x.Elem()
}

sw.Render(gengo.Snippet{
gengo.T: `
ctx = @FieldType'InjectContext(ctx, p.@Field)
`,
"Field": gengo.ID(f.Name()),
"FieldType": gengo.ID(f.Type()),
})
}
}
}

if impl != "" {
c.Render(gengo.Snippet{
gengo.T: `
if !forAlias {
c.Render(gengo.Snippet{
gengo.T: `
func (p *@Type) InjectContext(ctx @contextContext) (@contextContext) {
@provideFields
return @injectContext(ctx, p)
}
`,
"Type": gengo.ID(t.Obj()),
"injectContext": gengo.ID(impl + "InjectContext"),
"contextContext": gengo.ID("context.Context"),
})
"Type": gengo.ID(t.Obj()),
"injectContext": gengo.ID(impl + "InjectContext"),
"contextContext": gengo.ID("context.Context"),
"provideFields": provideFields,
})
}

return nil
}
Expand All @@ -126,16 +185,26 @@ func @Type'FromContext(ctx @contextContext) (*@Type, bool) {
func @Type'InjectContext(ctx @contextContext, tpe *@Type) (@contextContext) {
return @contextWithValue(ctx, context@Type{}, tpe)
}
func (p *@Type) InjectContext(ctx @contextContext) (@contextContext) {
return @Type'InjectContext(ctx, p)
}
`,
"Type": gengo.ID(t.Obj()),
"contextContext": gengo.ID("context.Context"),
"contextWithValue": gengo.ID("context.WithValue"),
})

if !forAlias {
c.Render(gengo.Snippet{
gengo.T: `
func (p *@Type) InjectContext(ctx @contextContext) (@contextContext) {
@provideFields
return @Type'InjectContext(ctx, p)
}
`,
"Type": gengo.ID(t.Obj()),
"contextContext": gengo.ID("context.Context"),
"contextWithValue": gengo.ID("context.WithValue"),
"provideFields": provideFields,
})
}
}

return nil
Expand Down Expand Up @@ -187,7 +256,7 @@ if value, ok := @FieldType'FromContext(ctx); ok {
if !strings.Contains(injectTag, ",optional") {
sw.Render(gengo.Snippet{
gengo.T: `else {
return @errorsErrorf("missing provider %T", v.@Field)
return @errorsErrorf("missing provider %T.@Field", v)
}
`,
"Field": gengo.ID(f.Name()),
Expand Down
4 changes: 4 additions & 0 deletions example/apis/org/zz_generated.operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ func init() {
R.Register(courier.NewRouter(&ListOrg{}))
}

func (*ListOrg) ResponseStatusCode() int {
return 200
}

func (*ListOrg) ResponseContent() any {
return new(DataList[Info])
}
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ go 1.23.2

require (
github.com/davecgh/go-spew v1.1.1
github.com/go-courier/logr v0.3.0
github.com/go-courier/logr v0.3.1
github.com/go-json-experiment/json v0.0.0-20240815175050-ebd3a8989ca1
github.com/juju/ansiterm v1.0.0
github.com/octohelm/gengo v0.0.0-20241014043309-2344b8632080
github.com/octohelm/gengo v0.0.0-20241021060200-490be2d0c7f4
github.com/octohelm/x v0.0.0-20241011014327-0fcf864c84d6
github.com/onsi/gomega v1.34.2
golang.org/x/net v0.30.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-courier/logr v0.3.0 h1:0VEQB1b53EmYQ+ZehrIgD8l2IO+WX7TY+CqzlykIFmo=
github.com/go-courier/logr v0.3.0/go.mod h1:OI7f/JCFZ1ZMD5qG3bIJr5WMNnGzd24+II1D9D9w5x4=
github.com/go-courier/logr v0.3.1 h1:RcHM7qpO8OpuV+zFvJMXtJEspTnnYmT6uGiAomwb8X8=
github.com/go-courier/logr v0.3.1/go.mod h1:NQWi+TSv0rS1RfyWHv7MNEI5cNy9NR6k1n8R24uHgdY=
github.com/go-json-experiment/json v0.0.0-20240815175050-ebd3a8989ca1 h1:xcuWappghOVI8iNWoF2OKahVejd1LSVi/v4JED44Amo=
github.com/go-json-experiment/json v0.0.0-20240815175050-ebd3a8989ca1/go.mod h1:BWmvoE1Xia34f3l/ibJweyhrT+aROb/FQ6d+37F0e2s=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
Expand Down Expand Up @@ -30,8 +30,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/octohelm/gengo v0.0.0-20241014043309-2344b8632080 h1:JydhpFfiuBsIXWs4f3GQJDpCmG5LFKgzB7CFkRGWs3s=
github.com/octohelm/gengo v0.0.0-20241014043309-2344b8632080/go.mod h1:7bkbdNmnQEmnVbvdSJdRwvGmm2/3KgvxXSDi/nXCOk8=
github.com/octohelm/gengo v0.0.0-20241021060200-490be2d0c7f4 h1:zd0FrKYk+2aFBbS90nzdXhSpJadMwfZBb9hSamuLEVQ=
github.com/octohelm/gengo v0.0.0-20241021060200-490be2d0c7f4/go.mod h1:7bkbdNmnQEmnVbvdSJdRwvGmm2/3KgvxXSDi/nXCOk8=
github.com/octohelm/x v0.0.0-20241011014327-0fcf864c84d6 h1:4royPn66/B49ftj5Jh7ZWuJ8A2B/2l0UTQ6EFLf29Vk=
github.com/octohelm/x v0.0.0-20241011014327-0fcf864c84d6/go.mod h1:c8k4TZNwZXVCclWxFU8dm67PQ98APzKW3f5JJykQ2uo=
github.com/onsi/ginkgo/v2 v2.20.1 h1:YlVIbqct+ZmnEph770q9Q7NVAz4wwIiVNahee6JyUzo=
Expand Down
2 changes: 1 addition & 1 deletion internal/request/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ type routeHttpHandler struct {
func (h *routeHttpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

ctx = courierhttp.HttpRequestInjectContext(ctx, &courierhttp.HttpRequest{Request: r})
ctx = courierhttp.RequestInjectContext(ctx, r)

info := httprequest.From(r)

Expand Down
4 changes: 1 addition & 3 deletions pkg/courierhttp/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import (
)

// +gengo:injectable:provider
type HttpRequest struct {
*http.Request
}
type Request = http.Request

// +gengo:injectable:provider
type OperationInfo struct {
Expand Down
11 changes: 6 additions & 5 deletions pkg/courierhttp/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package courierhttp

import (
"context"
"fmt"
"github.com/go-courier/logr"
"io"
"net/http"
Expand Down Expand Up @@ -45,7 +46,7 @@ type FileHeader interface {
Header() http.Header
}

type Request = httprequest.Request
type RequestInfo = httprequest.Request

type ResponseSetting interface {
SetStatusCode(statusCode int)
Expand Down Expand Up @@ -121,11 +122,11 @@ type Response[T any] interface {
}

type ErrResponseWriter interface {
WriteErr(ctx context.Context, rw http.ResponseWriter, req Request, err error)
WriteErr(ctx context.Context, rw http.ResponseWriter, req RequestInfo, err error)
}

type ResponseWriter interface {
WriteResponse(ctx context.Context, rw http.ResponseWriter, req Request) error
WriteResponse(ctx context.Context, rw http.ResponseWriter, req RequestInfo) error
}

type errorResponse struct {
Expand Down Expand Up @@ -192,7 +193,7 @@ func (r *response[T]) Meta() courier.Metadata {
return r.meta
}

func (r *response[T]) WriteResponse(ctx context.Context, rw http.ResponseWriter, req Request) (finalErr error) {
func (r *response[T]) WriteResponse(ctx context.Context, rw http.ResponseWriter, req RequestInfo) (finalErr error) {
defer func() {
r.v = nil
if finalErr != nil {
Expand Down Expand Up @@ -278,7 +279,7 @@ func (r *response[T]) WriteResponse(ctx context.Context, rw http.ResponseWriter,
// forward result
rw.WriteHeader(r.statusCode)
if _, err := v.Into(rw); err != nil {
return err
return fmt.Errorf("forward failed: %w", err)
}
default:
if resp == nil {
Expand Down
12 changes: 6 additions & 6 deletions pkg/courierhttp/transport/incoming_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
)

type IncomingTransport interface {
UnmarshalOperator(ctx context.Context, info courierhttp.Request, op any) error
WriteResponse(ctx context.Context, rw http.ResponseWriter, result any, info courierhttp.Request)
UnmarshalOperator(ctx context.Context, info courierhttp.RequestInfo, op any) error
WriteResponse(ctx context.Context, rw http.ResponseWriter, result any, info courierhttp.RequestInfo)
}

func NewIncomingTransport(ctx context.Context, v any) (IncomingTransport, error) {
Expand All @@ -23,11 +23,11 @@ func NewIncomingTransport(ctx context.Context, v any) (IncomingTransport, error)
type incomingTransport struct {
}

func (t *incomingTransport) UnmarshalOperator(ctx context.Context, ireq courierhttp.Request, op any) error {
func (t *incomingTransport) UnmarshalOperator(ctx context.Context, ireq courierhttp.RequestInfo, op any) error {
return content.UnmarshalRequestInfo(ireq, op)
}

func (i *incomingTransport) WriteResponse(ctx context.Context, rw http.ResponseWriter, ret any, req courierhttp.Request) {
func (i *incomingTransport) WriteResponse(ctx context.Context, rw http.ResponseWriter, ret any, req courierhttp.RequestInfo) {
if upgrader, ok := ret.(Upgrader); ok {
if err := upgrader.Upgrade(rw, req.Underlying()); err != nil {
i.writeErrResp(ctx, rw, err, req)
Expand All @@ -42,13 +42,13 @@ func (i *incomingTransport) WriteResponse(ctx context.Context, rw http.ResponseW
}
}

func (i *incomingTransport) writeResp(ctx context.Context, rw http.ResponseWriter, ret any, req courierhttp.Request) {
func (i *incomingTransport) writeResp(ctx context.Context, rw http.ResponseWriter, ret any, req courierhttp.RequestInfo) {
if err := courierhttp.Wrap(ret).(courierhttp.ResponseWriter).WriteResponse(ctx, rw, req); err != nil {
logr.FromContext(ctx).Error(err)
}
}

func (i *incomingTransport) writeErrResp(ctx context.Context, rw http.ResponseWriter, err error, req courierhttp.Request) {
func (i *incomingTransport) writeErrResp(ctx context.Context, rw http.ResponseWriter, err error, req courierhttp.RequestInfo) {
if err := courierhttp.WrapError(err).(courierhttp.ResponseWriter).WriteResponse(ctx, rw, req); err != nil {
logr.FromContext(ctx).Error(err)
}
Expand Down
36 changes: 13 additions & 23 deletions pkg/courierhttp/zz_generated.injectable.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,38 @@ import (
context "context"
)

type contextHttpRequest struct{}
type contextOperationInfo struct{}

func HttpRequestFromContext(ctx context.Context) (*HttpRequest, bool) {
if v, ok := ctx.Value(contextHttpRequest{}).(*HttpRequest); ok {
func OperationInfoFromContext(ctx context.Context) (*OperationInfo, bool) {
if v, ok := ctx.Value(contextOperationInfo{}).(*OperationInfo); ok {
return v, true
}
return nil, false
}

func HttpRequestInjectContext(ctx context.Context, tpe *HttpRequest) context.Context {
return context.WithValue(ctx, contextHttpRequest{}, tpe)
func OperationInfoInjectContext(ctx context.Context, tpe *OperationInfo) context.Context {
return context.WithValue(ctx, contextOperationInfo{}, tpe)
}
func (p *OperationInfo) InjectContext(ctx context.Context) context.Context {

func (p *HttpRequest) InjectContext(ctx context.Context) context.Context {
return HttpRequestInjectContext(ctx, p)
return OperationInfoInjectContext(ctx, p)
}

func (v *HttpRequest) Init(ctx context.Context) error {
func (v *OperationInfo) Init(ctx context.Context) error {

return nil
}

type contextOperationInfo struct{}
type contextRequest struct{}

func OperationInfoFromContext(ctx context.Context) (*OperationInfo, bool) {
if v, ok := ctx.Value(contextOperationInfo{}).(*OperationInfo); ok {
func RequestFromContext(ctx context.Context) (*Request, bool) {
if v, ok := ctx.Value(contextRequest{}).(*Request); ok {
return v, true
}
return nil, false
}

func OperationInfoInjectContext(ctx context.Context, tpe *OperationInfo) context.Context {
return context.WithValue(ctx, contextOperationInfo{}, tpe)
}

func (p *OperationInfo) InjectContext(ctx context.Context) context.Context {
return OperationInfoInjectContext(ctx, p)
}

func (v *OperationInfo) Init(ctx context.Context) error {

return nil
func RequestInjectContext(ctx context.Context, tpe *Request) context.Context {
return context.WithValue(ctx, contextRequest{}, tpe)
}

type contextRouteDescriber struct{}
Expand Down
8 changes: 4 additions & 4 deletions pkg/validator/internal/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ func (v *validators) defaultRule(t reflect.Type) string {
case reflect.Uint:
return "@uint"
case reflect.Uint8:
return "@uin8"
return "@uint8"
case reflect.Uint16:
return "@uin16"
return "@uint16"
case reflect.Uint32:
return "@uin32"
return "@uint32"
case reflect.Uint64:
return "@uin64"
return "@uint64"
case reflect.Float32:
return "@float"
case reflect.Float64:
Expand Down
2 changes: 2 additions & 0 deletions pkg/validator/validators/zz_generated.runtimedoc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ func runtimeDoc(v any, names ...string) ([]string, bool) {
func (v FloatValidator) RuntimeDoc(names ...string) ([]string, bool) {
if len(names) > 0 {
switch names[0] {
case "BitSize":
return []string{}, true
case "MaxDigits":
return []string{}, true
case "DecimalDigits":
Expand Down

0 comments on commit f0776d2

Please sign in to comment.