diff --git a/go-runtime/compile/schema.go b/go-runtime/compile/schema.go index 1904b1a380..7a19b968cd 100644 --- a/go-runtime/compile/schema.go +++ b/go-runtime/compile/schema.go @@ -10,6 +10,7 @@ import ( "reflect" "strings" "sync" + "unicode" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" @@ -169,12 +170,16 @@ func visitFile(pctx *parseContext, node *ast.File) error { return nil } -func isType[T types.Type](t types.Type) bool { - if _, ok := t.(*types.Named); ok { - t = t.Underlying() +func typeCheck[T types.Type](original types.Object) (*T, bool) { + t := original.Type() + if _, ok := original.Type().(*types.Named); ok { + t = original.Type().Underlying() } - _, ok := t.(T) - return ok + + if t, ok := t.(T); ok { + return &t, true + } + return nil, false } func checkSignature(sig *types.Signature) (req, resp *types.Var, err error) { @@ -190,11 +195,21 @@ func checkSignature(sig *types.Signature) (req, resp *types.Var, err error) { if !types.AssertableTo(contextIfaceType(), params.At(0).Type()) { return nil, nil, fmt.Errorf("first parameter must be of type context.Context but is %s", params.At(0).Type()) } + if params.Len() == 2 { - if !isType[*types.Struct](params.At(1).Type()) { - return nil, nil, fmt.Errorf("second parameter must be a struct but is %s", params.At(1).Type()) + structParam := results.At(0) + if s, ok := typeCheck[*types.Struct](structParam); ok { + for i := 0; i < (*s).NumFields(); i++ { + fieldName := (*s).Field(i).Name() + if len(fieldName) > 0 && unicode.IsLower(rune(fieldName[0])) { + return nil, nil, fmt.Errorf("params field %s must be exported by starting with an uppercase letter", fieldName) + } + } + } else { + return nil, nil, fmt.Errorf("second parameter must be a struct but is %s", structParam.Type()) } - req = params.At(1) + + req = structParam } if results.Len() > 2 { @@ -207,8 +222,16 @@ func checkSignature(sig *types.Signature) (req, resp *types.Var, err error) { return nil, nil, fmt.Errorf("must return an error but is %s", results.At(0).Type()) } if results.Len() == 2 { - if !isType[*types.Struct](results.At(0).Type()) { - return nil, nil, fmt.Errorf("first result must be a struct but is %s", results.At(0).Type()) + structResult := results.At(0) + if s, ok := typeCheck[*types.Struct](structResult); ok { + for i := 0; i < (*s).NumFields(); i++ { + fieldName := (*s).Field(i).Name() + if len(fieldName) > 0 && unicode.IsLower(rune(fieldName[0])) { + return nil, nil, fmt.Errorf("results field %s must be exported by starting with an uppercase letter", fieldName) + } + } + } else { + return nil, nil, fmt.Errorf("first result must be a struct but is %s", structResult.Type()) } resp = results.At(0) }