Skip to content

Commit

Permalink
Merge pull request #35 from lazygophers/luoxin
Browse files Browse the repository at this point in the history
针对类型的默认 tag
  • Loading branch information
Luoxin authored Oct 14, 2024
2 parents d20fdad + 0b2b74e commit f07db7a
Show file tree
Hide file tree
Showing 14 changed files with 260 additions and 48 deletions.
8 changes: 8 additions & 0 deletions cli/gen/gen_add_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ var addRpcCmd = &cobra.Command{
}
}

// 给名字变成驼峰ca
if msg == nil {
msg = pb.GetMessage(stringx.ToCamel(v))
if msg != nil {
opt.Model = stringx.ToCamel(v)
}
}

if msg == nil {
log.Errorf("not found model:%v", v)
pterm.Error.Printfln("not found model:%v", v)
Expand Down
22 changes: 13 additions & 9 deletions codegen/generate_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ func generateImpl(pb *PbPackage, rpc *PbRPC) (err error) {
log.Infof("gen impl action %s", rpc.genOption.Action)

args := map[string]any{
"PB": pb,
"RpcName": rpc.Name,
"RequestType": rpc.rpc.RequestType,
"ResponseType": rpc.rpc.ReturnsType,
"PB": pb,
"RpcName": rpc.Name,
"RequestType": rpc.RequestType(),
"ResponseType": rpc.ReturnsType(),
"RequestPackage": rpc.RequestPackage(),
"ResponsePackage": rpc.ResponsePackage(),
}

if rpc.genOption.Model != "" {
Expand Down Expand Up @@ -445,11 +447,13 @@ func generateImplClient(pb *PbPackage, rpc *PbRPC) (err error) {
pterm.Info.Printfln("try generate impl client %s", rpc.Name)

args := map[string]any{
"PB": pb,
"RpcName": rpc.Name,
"RequestType": rpc.rpc.RequestType,
"ResponseType": rpc.rpc.ReturnsType,
"RPC": pbRpc2Route(rpc),
"PB": pb,
"RpcName": rpc.Name,
"RequestType": rpc.RequestType(),
"ResponseType": rpc.ReturnsType(),
"RequestPackage": rpc.RequestPackage(),
"ResponsePackage": rpc.ResponsePackage(),
"RPC": pbRpc2Route(rpc),
}

if rpc.genOption.Model != "" {
Expand Down
95 changes: 77 additions & 18 deletions codegen/generate_struct_tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,28 +141,33 @@ func InjectTagWriteFile(inputPath string, areas []textArea) error {
return nil
}

func gormTagStr2Map(s string) map[string]string {
func gormTagStr2Map(items []string) map[string]string {
m := make(map[string]string)
for _, v := range strings.Split(s, ";") {
idx := strings.Index(v, ":")
if idx < 0 {
m[v] = ""
} else {
m[v[:idx]] = v[idx+1:]

for _, item := range items {
for _, v := range strings.Split(item, ";") {
idx := strings.Index(v, ":")
if idx < 0 {
m[v] = ""
} else {
m[v[:idx]] = v[idx+1:]
}
}
}

return m
}

func tagStr2Map(s string) map[string]string {
func tagStr2Map(items []string) map[string]string {
m := make(map[string]string)
for _, v := range strings.Split(s, ",") {
idx := strings.Index(v, "=")
if idx < 0 {
m[v] = ""
} else {
m[v[:idx]] = v[idx+1:]
for _, item := range items {
for _, v := range strings.Split(item, ",") {
idx := strings.Index(v, "=")
if idx < 0 {
m[v] = ""
} else {
m[v[:idx]] = v[idx+1:]
}
}
}

Expand Down Expand Up @@ -314,9 +319,64 @@ func InjectTagParseFile(inputPath string) ([]textArea, error) {
}
}

tagsMap := state.Config.DefaultTag[fieldName]
if len(tagsMap) == 0 {
tagsMap = make(map[string]string)
// 先按照类型获取一下
var getFieldType func(xx ast.Expr) string
getFieldType = func(xx ast.Expr) string {
switch x := xx.(type) {
case *ast.Ident:
return x.Name

case *ast.StarExpr:
return "*" + getFieldType(x.X)

case *ast.ArrayType:
return "array"

case *ast.MapType:
return "map"

case *ast.SelectorExpr:
return "*" + getFieldType(x.X)

default:
log.Panicf("unknown type %T", x)
}

return ""
}

fieldType := getFieldType(field.Type)

log.Infof("field type: %s", fieldType)

tagsMap := make(map[string][]string)

if tm := state.Config.DefaultTag[fieldType]; tm != nil {
for k, v := range tm {
tagsMap[k] = append(tagsMap[k], v)
}
}
switch fieldType {
case "int32", "int64", "uint32", "uint64", "sint32", "sint64":

case "float", "double", "float32", "float64":

case "string", "bytes":

case "bool":

default:
if tm := state.Config.DefaultTag["object"]; tm != nil {
for k, v := range tm {
tagsMap[k] = append(tagsMap[k], v)
}
}
}

if tm := state.Config.DefaultTag[fieldName]; tm != nil {
for k, v := range tm {
tagsMap[k] = append(tagsMap[k], v)
}
}

if isModelStruct {
Expand All @@ -329,7 +389,6 @@ func InjectTagParseFile(inputPath string) ([]textArea, error) {
"column": fieldName,
}, injectTags.get("gorm"))
}

}

for key, value := range tagsMap {
Expand Down
4 changes: 3 additions & 1 deletion codegen/generate_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ func GenerateStateTable(pb *PbPackage) (err error) {

// table 文件为覆盖生成
args := map[string]interface{}{
"PB": pb,
"PB": pb,
"EnableErrorNotFound": !state.Config.Tables.DisableErrorNotFound,
"EnableErrorDuplicateKey": !state.Config.Tables.DisableErrorDuplicateKey,
}

// 读取 Models
Expand Down
76 changes: 70 additions & 6 deletions codegen/pbparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,45 @@ type PbRPC struct {
rpc *proto.RPC
Name string

options map[string]map[string]string
comment *PbComment
genOption *PbRpcGenOptions
options map[string]map[string]string
comment *PbComment
genOption *PbRpcGenOptions
requestType string
returnsType string
requestPackage string
responsePackage string
}

func (p *PbRPC) setDefaultPackage(pkg string) {
if p.requestPackage == "" {
p.requestPackage = pkg
}

if p.responsePackage == "" {
p.responsePackage = pkg
}
}

func (p *PbRPC) RPC() *proto.RPC {
return p.rpc
}

func (p *PbRPC) RequestType() string {
return p.requestType
}

func (p *PbRPC) ReturnsType() string {
return p.returnsType
}

func (p *PbRPC) RequestPackage() string {
return p.requestPackage
}

func (p *PbRPC) ResponsePackage() string {
return p.responsePackage
}

func (p *PbRPC) walk() {
for _, option := range p.rpc.Options {
p.options[option.Name] = make(map[string]string, len(option.AggregatedConstants))
Expand Down Expand Up @@ -209,9 +239,6 @@ func (p *PbRPC) walk() {
return
}

log.Info(v)
log.Info(gen)

if gen.Role != "" {
p.genOption.Role = gen.Role
}
Expand Down Expand Up @@ -272,6 +299,33 @@ func (p *PbRPC) walk() {
p.genOption.GenTo = "impl"
}
}

if strings.Contains(p.rpc.RequestType, ".") {
text := p.rpc.RequestType
if idx := strings.LastIndex(text, "."); idx > 0 {
p.requestType = text[idx+1:]
text = text[:idx]
}
if idx := strings.LastIndex(text, "."); idx > 0 {
p.requestPackage = text[idx+1:]
} else {
p.requestPackage = text
}
}

if strings.Contains(p.rpc.ReturnsType, ".") {
text := p.rpc.ReturnsType
if idx := strings.LastIndex(text, "."); idx > 0 {
p.requestType = text[idx+1:]
text = text[:idx]
}
if idx := strings.LastIndex(text, "."); idx > 0 {
p.responsePackage = text[idx+1:]
} else {
p.responsePackage = text
}
}

}

func NewPbRPC(rpc *proto.RPC) *PbRPC {
Expand All @@ -283,7 +337,10 @@ func NewPbRPC(rpc *proto.RPC) *PbRPC {
Method: "POST",
Path: "/" + rpc.Name,
},
requestType: rpc.RequestType,
returnsType: rpc.ReturnsType,
}

p.walk()
return p
}
Expand Down Expand Up @@ -756,6 +813,8 @@ func (p *PbPackage) Walk() {
}),
)

// 处理一下option

if o, ok := p.optionMap["go_package"]; ok {
p.RawGoPackage = o.Value
idx := strings.Index(p.RawGoPackage, ";")
Expand All @@ -778,6 +837,11 @@ func (p *PbPackage) Walk() {
} else {
p.Host = "*"
}

// 回填一些默认值
candy.Each(p.rpcs, func(r *PbRPC) {
r.setDefaultPackage(p.PackageName())
})
}

func NewPbPackage(protoFilePath string, p *proto.Proto) *PbPackage {
Expand Down
4 changes: 2 additions & 2 deletions codegen/template/impl/.rpc.gtpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .PB.GoPackageName }}.{{ .RequestType }}) (*{{ .PB.GoPackageName }}.{{ .ResponseType }}, error) {
var rsp {{ .PB.GoPackageName }}.{{ .ResponseType }}
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .RequestPackage }}.{{ .RequestType }}) (*{{ .ResponsePackage }}.{{ .ResponseType }}, error) {
var rsp {{ .ResponsePackage }}.{{ .ResponseType }}

return &rsp, nil
}
6 changes: 3 additions & 3 deletions codegen/template/impl/add.rpc.gtpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .PB.GoPackageName }}.{{ .RequestType }}) (*{{ .PB.GoPackageName }}.{{ .ResponseType }}, error) {
var rsp {{ .PB.GoPackageName }}.{{ .ResponseType }}
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .RequestPackage }}.{{ .RequestType }}) (*{{ .ResponsePackage }}.{{ .ResponseType }}, error) {
var rsp {{ .ResponsePackage }}.{{ .ResponseType }}

{{ ToSmallCamel (TrimPrefix .Model "Model") }} := *req.{{ ToCamel (TrimPrefix .Model "Model") }}
{{ ToSmallCamel (TrimPrefix .Model "Model") }}.{{ ToCamel .PrimaryKey }} = 0
Expand All @@ -12,7 +12,7 @@ func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .PB.GoPackageName }}.{{ .RequestType
return nil, err
}

rsp.{{ ToCamel (TrimPrefix .Model "Model") }} = &{{ ToSnake (TrimPrefix .Model "Model") }}
rsp.{{ ToCamel (TrimPrefix .Model "Model") }} = &{{ ToSmallCamel (TrimPrefix .Model "Model") }}

return &rsp, nil
}
4 changes: 2 additions & 2 deletions codegen/template/impl/del.rpc.gtpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .PB.GoPackageName }}.{{ .RequestType }}) (*{{ .PB.GoPackageName }}.{{ .ResponseType }}, error) {
var rsp {{ .PB.GoPackageName }}.{{ .ResponseType }}
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .RequestPackage }}.{{ .RequestType }}) (*{{ .ResponsePackage }}.{{ .ResponseType }}, error) {
var rsp {{ .ResponsePackage }}.{{ .ResponseType }}

err := state.{{ ToCamel (TrimPrefix .Model "Model") }}.
NewScoop().
Expand Down
4 changes: 2 additions & 2 deletions codegen/template/impl/get.rpc.gtpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .PB.GoPackageName }}.{{ .RequestType }}) (*{{ .PB.GoPackageName }}.{{ .ResponseType }}, error) {
var rsp {{ .PB.GoPackageName }}.{{ .ResponseType }}
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .RequestPackage }}.{{ .RequestType }}) (*{{ .ResponsePackage }}.{{ .ResponseType }}, error) {
var rsp {{ .ResponsePackage }}.{{ .ResponseType }}

{{ ToSmallCamel (TrimPrefix .Model "Model") }}, err := state.{{ ToCamel (TrimPrefix .Model "Model") }}.
NewScoop().
Expand Down
4 changes: 2 additions & 2 deletions codegen/template/impl/list.rpc.gtpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .PB.GoPackageName }}.{{ .RequestType }}) (*{{ .PB.GoPackageName }}.{{ .ResponseType }}, error) {
var rsp {{ .PB.GoPackageName }}.{{ .ResponseType }}
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .RequestPackage }}.{{ .RequestType }}) (*{{ .ResponsePackage }}.{{ .ResponseType }}, error) {
var rsp {{ .ResponsePackage }}.{{ .ResponseType }}

scoop := state.{{ ToCamel (TrimPrefix .Model "Model") }}.NewScoop()

Expand Down
4 changes: 2 additions & 2 deletions codegen/template/impl/set.rpc.gtpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .PB.GoPackageName }}.{{ .RequestType }}) (*{{ .PB.GoPackageName }}.{{ .ResponseType }}, error) {
var rsp {{ .PB.GoPackageName }}.{{ .ResponseType }}
func {{ .RpcName }}(ctx *lrpc.Ctx, req *{{ .RequestPackage }}.{{ .RequestType }}) (*{{ .ResponsePackage }}.{{ .ResponseType }}, error) {
var rsp {{ .ResponsePackage }}.{{ .ResponseType }}

{{ ToSmallCamel (TrimPrefix .Model "Model") }} := req.{{ ToCamel (TrimPrefix .Model "Model") }}

Expand Down
4 changes: 3 additions & 1 deletion codegen/template/state/table.gtpl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ func ConnectDatabase() (err error) {
return err
}

{{ range $key, $value := .Models}} {{TrimPrefix $value "Model"}} = db.NewModel[{{ $.PB.GoPackageName }}.{{ $value }}](Db()).SetNotFound(xerror.NewError(int32({{ $.PB.GoPackageName }}.ErrCode_{{TrimPrefix $value "Model"}}NotFound)))
{{ range $key, $value := .Models}} {{TrimPrefix $value "Model"}} = db.NewModel[{{ $.PB.GoPackageName }}.{{ $value }}](Db()){{ if $.EnableErrorNotFound }}.
SetNotFound(xerror.NewError(int32({{ $.PB.GoPackageName }}.ErrCode_{{TrimPrefix $value "Model"}}NotFound))){{ end }}{{ if $.EnableErrorDuplicateKey }}.
SetDuplicatedKeyError(xerror.NewError(int32({{ $.PB.GoPackageName }}.ErrCode_{{TrimPrefix $value "Model"}}DuplicateKey))){{ end }}
{{ end }}
log.Info("connect mysql successfully")

Expand Down
5 changes: 5 additions & 0 deletions example.codegen.cfg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ tables:
# 关闭自动生成 column 的gorm tag,默认为 false
disable-gorm-tag-column: false

# 关闭关于错误:数据未找到的指定错误生成
disable-error-not_found: false
# 关闭关于错误:唯一键冲突的指定错误生成
disable-error-duplicate_key: false

# go.mod 相关配置
go-mod:
# 对 go.mod 相关 api 请求时用的代理地址,如果不填则使用系统配置的第一个
Expand Down
Loading

0 comments on commit f07db7a

Please sign in to comment.