Skip to content

Commit

Permalink
feat(operatorgen): support scan custom error struct
Browse files Browse the repository at this point in the history
  • Loading branch information
morlay committed May 6, 2024
1 parent dc50fb8 commit 6b6f6ac
Show file tree
Hide file tree
Showing 38 changed files with 341 additions and 69 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@ on:
jobs:
test:
runs-on: ubuntu-latest
env:
GOEXPERIMENT: "rangefunc"

steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
- uses: actions/setup-go@v4
with:
go-version: '^1.22'

- run: make test.race
- run: make test.race


6 changes: 3 additions & 3 deletions devpkg/clientgen/gen.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package openapi
package clientgen

import (
"context"
Expand Down Expand Up @@ -428,11 +428,11 @@ const (
`,
"enums": gengo.MapSnippet(enumType.Enum, func(enum any) gengo.Snippet {
return gengo.Snippet{gengo.T: `
@NamePrefix'__@Name @Type = @value
@NamePrefix'__@OrgName @Type = @value
`,
"Type": gengo.ID(gengo.UpperCamelCase(name)),
"NamePrefix": gengo.ID(gengo.UpperSnakeCase(name)),
"Name": gengo.ID(gengo.UpperCamelCase(enum.(string))),
"OrgName": gengo.ID(gengo.UpperCamelCase(enum.(string))),
"value": enum,
}
}),
Expand Down
26 changes: 13 additions & 13 deletions devpkg/operatorgen/gen.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package openapi
package operatorgen

import (
"fmt"
Expand All @@ -15,7 +15,7 @@ import (
"github.com/octohelm/courier/pkg/courierhttp"
"github.com/octohelm/gengo/pkg/gengo"
gengotypes "github.com/octohelm/gengo/pkg/types"
typesutil "github.com/octohelm/x/types"
typex "github.com/octohelm/x/types"
)

func init() {
Expand All @@ -36,7 +36,7 @@ func (g *operatorGen) GenerateType(c gengo.Context, named *types.Named) error {
return gengo.ErrSkip
}

if !isCourierOperator(c, typesutil.FromTType(types.NewPointer(named)), g.resolvePkg) {
if !isCourierOperator(c, typex.FromTType(types.NewPointer(named)), g.resolvePkg) {
return gengo.ErrSkip
}

Expand All @@ -46,12 +46,12 @@ func (g *operatorGen) GenerateType(c gengo.Context, named *types.Named) error {
}

func (g *operatorGen) generateReturns(c gengo.Context, named *types.Named) {
method, ok := typesutil.FromTType(types.NewPointer(named)).MethodByName("Output")
method, ok := typex.FromTType(types.NewPointer(named)).MethodByName("Output")
if ok {
results, n := c.Package(named.Obj().Pkg().Path()).ResultsOf(method.(*typesutil.TMethod).Func)
results, n := c.Package(named.Obj().Pkg().Path()).ResultsOf(method.(*typex.TMethod).Func)
if n == 2 {
g.generateSuccessReturn(c, named, results[0])
g.generateErrorsReturn(c, named, method.(*typesutil.TMethod).Func)
g.generateErrorsReturn(c, named, method.(*typex.TMethod).Func)
}
}
}
Expand Down Expand Up @@ -222,9 +222,9 @@ func (g *operatorGen) resolvePkg(c gengo.Context, importPath string) *types.Pack
}

func (g *operatorGen) firstValueOfFunc(c gengo.Context, named *types.Named, name string) (interface{}, bool) {
method, ok := typesutil.FromTType(types.NewPointer(named)).MethodByName(name)
method, ok := typex.FromTType(types.NewPointer(named)).MethodByName(name)
if ok {
fn := method.(*typesutil.TMethod).Func
fn := method.(*typex.TMethod).Func
results, n := c.Package(fn.Pkg().Path()).ResultsOf(fn)
if n == 1 {
for _, r := range results[0] {
Expand All @@ -240,11 +240,11 @@ func (g *operatorGen) firstValueOfFunc(c gengo.Context, named *types.Named, name

var typOperator = reflect.TypeOf((*courier.Operator)(nil)).Elem()

func isCourierOperator(c gengo.Context, tpe typesutil.Type, lookup func(c gengo.Context, importPath string) *types.Package) bool {
func isCourierOperator(c gengo.Context, tpe typex.Type, lookup func(c gengo.Context, importPath string) *types.Package) bool {
switch tpe.(type) {
case *typesutil.RType:
return tpe.Implements(typesutil.FromRType(typOperator))
case *typesutil.TType:
case *typex.RType:
return tpe.Implements(typex.FromRType(typOperator))
case *typex.TType:
pkg := lookup(c, typOperator.PkgPath())
if pkg == nil {
return false
Expand All @@ -253,7 +253,7 @@ func isCourierOperator(c gengo.Context, tpe typesutil.Type, lookup func(c gengo.
if t == nil {
return false
}
return types.Implements(tpe.(*typesutil.TType).Type, t.Type().Underlying().(*types.Interface))
return types.Implements(tpe.(*typex.TType).Type, t.Type().Underlying().(*types.Interface))
}
return false
}
Expand Down
132 changes: 130 additions & 2 deletions devpkg/operatorgen/statuserr_scanner.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package openapi
package operatorgen

import (
"fmt"
typex "github.com/octohelm/x/types"
"go/ast"
"go/constant"
"go/token"
"go/types"
"net/http"
"path/filepath"
"reflect"
"sort"
"strconv"
Expand Down Expand Up @@ -74,6 +79,10 @@ func (s *statusErrScanner) StatusErrorsInFunc(ctx gengo.Context, typeFunc *types
tpe = p.Elem()
}
if named, ok := tpe.(*types.Named); ok {
if isErrWithStatusCodeInterface(named) {
return s.scanErrWithStatusCodeInterface(ctx, named)
}

if isTypeStatusErr(named) {
ast.Inspect(r.Expr, func(node ast.Node) bool {
switch x := node.(type) {
Expand Down Expand Up @@ -128,7 +137,6 @@ func (s *statusErrScanner) appendStateErrs(typeFunc *types.Func, statusErrs ...*

func (s *statusErrScanner) scanStatusErrIsExist(typeFunc *types.Func, pkg gengotypes.Package, obj types.Object, callIdent *ast.Ident, x *ast.CallExpr) bool {
if callIdent.Name == "Wrap" && obj.Pkg().Path() == statusErr.PkgPath() {

code := 0
key := ""
msg := ""
Expand Down Expand Up @@ -171,6 +179,126 @@ func (s *statusErrScanner) scanStatusErrIsExist(typeFunc *types.Func, pkg gengot
return false
}

var (
rtypeErrorWithStatusCode = typex.FromRType(reflect.TypeOf((*statuserror.ErrorWithStatusCode)(nil)).Elem())
)

func isErrWithStatusCodeInterface(named *types.Named) bool {
if named != nil {
return typex.FromTType(types.NewPointer(named)).Implements(rtypeErrorWithStatusCode)
}
return false
}

func (s *statusErrScanner) resolveStateCode(ctx gengo.Context, named *types.Named) (int, bool) {
method, ok := typex.FromTType(types.NewPointer(named)).MethodByName("StatusCode")
if ok {
m := method.(*typex.TMethod)
if m.Func.Pkg() == nil {
return 0, false
}

results, n := ctx.Package(m.Func.Pkg().Path()).ResultsOf(m.Func)
if n == 1 {
for _, r := range results[0] {
if r.Value != nil && r.Value.Kind() == constant.Int {
v, err := strconv.ParseInt(r.Value.String(), 10, 64)
if err == nil {
return int(v), true
}
}
}
}
}

return 0, false
}

func (s *statusErrScanner) scanErrWithStatusCodeInterface(ctx gengo.Context, named *types.Named) (list []*statuserror.StatusErr) {
if named.Obj() == nil {
return nil
}

serr := &statuserror.StatusErr{
Key: filepath.Base(named.Obj().Pkg().Path()) + "." + named.Obj().Name(),
Code: http.StatusInternalServerError,
}

code, ok := s.resolveStateCode(ctx, named)
if ok {
serr.Code = code
}

method, ok := typex.FromTType(types.NewPointer(named)).MethodByName("Error")
if ok {
m := method.(*typex.TMethod)
if m.Func.Pkg() == nil {
return
}

results, n := ctx.Package(m.Func.Pkg().Path()).ResultsOf(m.Func)
if n == 1 {
for _, r := range results[0] {
switch x := r.Expr.(type) {
case *ast.BasicLit:
str, err := strconv.Unquote(x.Value)
if err == nil {
e := &(*serr)
e.Msg = str
list = append(list, e)
}
case *ast.CallExpr:
if selectExpr, ok := x.Fun.(*ast.SelectorExpr); ok {
if selectExpr.Sel.Name == "Sprintf" {
e := &(*serr)
e.Msg = fmtSprintfArgsAsTemplate(x.Args)
list = append(list, e)
}
}
}
}
}
}

return
}

func fmtSprintfArgsAsTemplate(args []ast.Expr) string {
if len(args) == 0 {
return ""
}

f := ""
fArgs := make([]any, 0, len(args))

toString := func(a *ast.BasicLit) string {
switch a.Kind {
case token.STRING:
v, _ := strconv.Unquote(a.Value)
return v
default:
return a.Value
}
}

for i, arg := range args {
switch a := arg.(type) {
case *ast.BasicLit:
if i == 0 {
f = toString(a)
} else {
fArgs = append(fArgs, toString(a))
}
case *ast.SelectorExpr:
fArgs = append(fArgs, fmt.Sprintf("{%s}", a.Sel.Name))
case *ast.Ident:
fArgs = append(fArgs, fmt.Sprintf("{%s}", a.Name))
}
}

return fmt.Sprintf(normalizeFormat(f), fArgs...)
}

func pickStatusErrorsFromDoc(lines []string) []*statuserror.StatusErr {
statusErrorList := make([]*statuserror.StatusErr, 0)

Expand Down
14 changes: 14 additions & 0 deletions devpkg/operatorgen/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package operatorgen

import (
"regexp"
)

// https://pkg.go.dev/fmt#hdr-Printing
var re = regexp.MustCompile(`%(\[([0-9]+)])?(([.0-9]+)|([#-+ 0]))?[vTtbcdoOqxXUeEfFgGps]`)

func normalizeFormat(s string) string {
return re.ReplaceAllStringFunc(s, func(seg string) string {
return "%v"
})
}
30 changes: 30 additions & 0 deletions devpkg/operatorgen/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package operatorgen

import (
"fmt"
"testing"
)

func Test_normalizeFormat(t *testing.T) {
cases := []struct {
actual string
expect string
}{
{
"%s %q %v %.1f %q %%",
"%v %v %v %v %v %%",
},
{
"%[2]s %[1]q",
"%v %v",
},
}

for _, tc := range cases {
t.Run(fmt.Sprintf("%q", tc.actual), func(t *testing.T) {
if expect := normalizeFormat(tc.actual); expect != tc.expect {
t.Fatalf("expect: %q, actual: %q", tc.expect, tc.actual)
}
})
}
}
2 changes: 1 addition & 1 deletion example/apis/blob/zz_generated.operator.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Package blob GENERATED BY gengo:operator
Package blob GENERATED BY gengo:operator
DON'T EDIT THIS FILE
*/
package blob
Expand Down
2 changes: 1 addition & 1 deletion example/apis/blob/zz_generated.runtimedoc.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Package blob GENERATED BY gengo:runtimedoc
Package blob GENERATED BY gengo:runtimedoc
DON'T EDIT THIS FILE
*/
package blob
Expand Down
17 changes: 17 additions & 0 deletions example/apis/org/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org

import (
"fmt"

"github.com/octohelm/courier/example/pkg/statuserr"
)

type ErrNotFound struct {
statuserr.NotFound

OrgName string
}

func (e ErrNotFound) Error() string {
return fmt.Sprintf("%s: 组织不存在", e.OrgName)
}
2 changes: 1 addition & 1 deletion example/apis/org/operator/zz_generated.runtimedoc.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Package operator GENERATED BY gengo:runtimedoc
Package operator GENERATED BY gengo:runtimedoc
DON'T EDIT THIS FILE
*/
package operator
Expand Down
10 changes: 3 additions & 7 deletions example/apis/org/org__get.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,11 @@ package org

import (
"context"
"net/http"
"time"

"github.com/octohelm/courier/example/apis/org/operator"
"github.com/octohelm/courier/pkg/courier"

"github.com/pkg/errors"

"github.com/octohelm/courier/pkg/courierhttp"
"github.com/octohelm/courier/pkg/statuserror"
)

func (GetOrg) MiddleOperators() courier.MiddleOperators {
Expand All @@ -23,12 +18,13 @@ func (GetOrg) MiddleOperators() courier.MiddleOperators {
// 查询组织信息
type GetOrg struct {
courierhttp.MethodGet `path:"/:orgName"`
OrgName string `name:"orgName" in:"path" `

OrgName string `name:"orgName" in:"path" `
}

func (c *GetOrg) Output(ctx context.Context) (any, error) {
if c.OrgName == "NotFound" {
return nil, statuserror.Wrap(errors.New("NotFound"), http.StatusNotFound, "NotFound")
return nil, &ErrNotFound{OrgName: c.OrgName}
}

return &Detail{
Expand Down
Loading

0 comments on commit 6b6f6ac

Please sign in to comment.