diff --git a/Makefile b/Makefile index b511ebb..ea94110 100644 --- a/Makefile +++ b/Makefile @@ -19,3 +19,5 @@ lint: ## lint go install github.com/golangci/golangci-lint/cmd/golangci-lint@lat help: ## Show this help @egrep '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | sed 's/Makefile://' | awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n\nTargets:\n"} /^[a-z0-9A-Z_-]+:.*?##/ { printf " \033[36m%-30s\033[0m %s\n", $$1, $$2 }' + +# make build && ./dist/cli_darwin_arm64/codegen gen -i ../example/example.proto -d add-rpc -m user -a add:,admin diff --git a/README.md b/README.md index 809fe08..2653bc2 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ - go.mod - go.sum - .editorconfig + - orm_gen.go(gorm 配套的接口协议的生成) + - table_gen.go(表名的生成) - state - table.go(数据库相关的内容) - cache.go(缓存相关的内容) @@ -59,7 +61,7 @@ 4. 执行命令 ```bash - codegen gen pb -i + codegen gen pb -i ``` ### 注意事项 diff --git a/cli/gen_add_rpc.go b/cli/gen_add_rpc.go new file mode 100644 index 0000000..656877b --- /dev/null +++ b/cli/gen_add_rpc.go @@ -0,0 +1,94 @@ +package cli + +import ( + "errors" + "github.com/lazygophers/codegen/codegen" + "github.com/lazygophers/log" + "github.com/lazygophers/utils/stringx" + "github.com/pterm/pterm" + "github.com/spf13/cobra" + "strings" +) + +var genAddRpcCmd = &cobra.Command{ + Use: "add-rpc", + Short: "add rpc to proto with model", + RunE: func(cmd *cobra.Command, args []string) (err error) { + opt := codegen.NewAddRpcOption() + + if v, err := cmd.Flags().GetString("default-role"); err == nil { + opt.DefaultRole = v + } + + if v, err := cmd.Flags().GetString("gen-to"); err == nil { + opt.GenTo = v + } + + var msg *codegen.PbMessage + if v, err := cmd.Flags().GetString("model"); err == nil && v != "" { + // 先看一下加了 Model 的 时候存在 + { + model := stringx.ToSnake(v) + model = strings.TrimPrefix(model, "model_") + model = "model_" + model + model = stringx.ToCamel(model) + + msg = pb.GetMessage(model) + if msg != nil { + opt.Model = model + } + } + + // 直接找名字的 + if msg == nil { + msg = pb.GetMessage(v) + if msg != nil { + opt.Model = v + } + } + + if msg == nil { + log.Errorf("not found model:%v", v) + pterm.Error.Printfln("not found model:%v", v) + return errors.New("not found model") + } + } else { + log.Errorf("missed argument `model`") + pterm.Error.Printfln("missing argument `model`") + return errors.New("missing argument `model`") + } + + if v, err := cmd.Flags().GetString("action"); err == nil && v != "" { + opt.ParseActions(v) + } else { + log.Errorf("missed argument `action`") + pterm.Error.Printfln("missing argument `action`") + return errors.New("missing argument `action`") + } + + if v, err := cmd.Flags().GetString("list-option"); err == nil && v != "" { + opt.ParseListOption(v, msg) + } + + err = codegen.GenerateAddRpc(pb, msg, opt) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + return err + }, +} + +func init() { + genAddRpcCmd.Flags().StringP("model", "m", "", "model name,must be specified") + genAddRpcCmd.Flags().StringP("gen-to", "t", "", "generate go source code path, will be used in gen go source code") + + genAddRpcCmd.Flags().StringP("action", "a", "", "action for adding, segmente by ';',\nsupports: add/set/get/list/del/update") + + genAddRpcCmd.Flags().StringP("list-option", "l", "", "list options, segmente by ';'") + + genAddRpcCmd.Flags().String("default-role", "", "default role") + + genCmd.AddCommand(genAddRpcCmd) +} diff --git a/cli/gen_state.go b/cli/gen_state.go index 2d5669c..59f77c0 100644 --- a/cli/gen_state.go +++ b/cli/gen_state.go @@ -36,5 +36,5 @@ func runGenState(cmd *cobra.Command, args []string) (err error) { } func init() { - rootCmd.AddCommand(genStateCmd) + genCmd.AddCommand(genStateCmd) } diff --git a/cli/gen_table.go b/cli/gen_table.go index 144114e..73665e4 100644 --- a/cli/gen_table.go +++ b/cli/gen_table.go @@ -14,6 +14,18 @@ var genTableCmd = &cobra.Command{ } func runGenTable(cmd *cobra.Command, args []string) (err error) { + err = codegen.GenerateOrm(pb) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + err = codegen.GenerateTableName(pb) + if err != nil { + log.Errorf("err:%v", err) + return err + } + err = codegen.GenerateTable(pb) if err != nil { log.Errorf("err:%v", err) diff --git a/codegen/generate_add_rpc.go b/codegen/generate_add_rpc.go new file mode 100644 index 0000000..41f7d34 --- /dev/null +++ b/codegen/generate_add_rpc.go @@ -0,0 +1,308 @@ +package codegen + +import ( + "bytes" + "fmt" + "github.com/lazygophers/log" + "github.com/lazygophers/utils/candy" + "github.com/pterm/pterm" + "os" + "strings" +) + +type AddRpcOptionAction struct { + Roles []string +} + +type AddRpcOption struct { + Model string + GenTo string + + DefaultRole string + + Action map[string]*AddRpcOptionAction + + ListOptions map[string]string +} + +func (p *AddRpcOption) ParseActions(s string) { + candy.Each(strings.Split(s, ";"), func(item string) { + idx := strings.Index(item, ":") + if idx < 0 { + p.Action[item] = &AddRpcOptionAction{} + return + } + + action := item[:idx] + p.Action[action] = &AddRpcOptionAction{} + + candy.Each(strings.Split(item[idx+1:], ","), func(item string) { + // TODO: 更多的格式解析 + p.Action[action].Roles = append(p.Action[action].Roles, item) + }) + }) + + for _, action := range p.Action { + action.Roles = candy.Map(action.Roles, func(s string) string { + if s == "" { + return p.DefaultRole + } + + return s + }) + + if len(action.Roles) == 0 { + action.Roles = append(action.Roles, p.DefaultRole) + } + + action.Roles = candy.Unique(action.Roles) + } +} + +func (p *AddRpcOption) ParseListOption(s string, msg *PbMessage) { + candy.Each(strings.Split(s, ";"), func(item string) { + var option string + var optionType string + + idx := strings.Index(item, ":") + if idx < 0 { + option = item + + // 如果没填,会按照字段类型填充默认数据 + if field, ok := msg.normalFields[option]; ok { + optionType = field.Type() + } + + } else { + option = item[:idx] + optionType = item[idx+1:] + } + + p.ListOptions[option] = optionType + }) +} + +func NewAddRpcOption() *AddRpcOption { + return &AddRpcOption{ + Action: map[string]*AddRpcOptionAction{}, + ListOptions: map[string]string{}, + } +} + +func GenerateAddRpc(pb *PbPackage, msg *PbMessage, opt *AddRpcOption) (err error) { + // 重新加载一下 pb 文件 + pb, err = ParseProto(pb.protoFilePath, true) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + // 找到最后一个 service 里面 rpc 的位置 + var lastRpcPos int + { + rpc := candy.Last(pb.RPCs()) + if rpc != nil { + // 找到这一个 rpc 的结尾 + s := rpc.RPC().Position.Offset + e := s + + for e <= len(pb.ProtoBuffer) { + for e < len(pb.ProtoBuffer) && + pb.ProtoBuffer[e] != '\n' { + e++ + } + + line := pb.ProtoBuffer[s:e] + + line = strings.ReplaceAll(line, " ", "") + line = strings.ReplaceAll(line, "\r", "") + line = strings.ReplaceAll(line, "\t", "") + if line == "};" || line == "}" { + lastRpcPos = e // 当前行的换行符 + break + } + + s = e + 1 + e = s + } + } + } + + if lastRpcPos == 0 { + service := candy.Last(pb.Services()) + if service != nil { + // 找到这一个 rpc 的结尾 + s := service.Service().Position.Offset + e := s + + for e <= len(pb.ProtoBuffer) { + for e < len(pb.ProtoBuffer) && + pb.ProtoBuffer[e] != '\n' { + e++ + } + + line := pb.ProtoBuffer[s:e] + + line = strings.ReplaceAll(line, " ", "") + line = strings.ReplaceAll(line, "\r", "") + line = strings.ReplaceAll(line, "\t", "") + if line == "}" { + lastRpcPos = s - 1 + break + } + + s = e + 1 + e = s + } + } + } + + if lastRpcPos == 0 { + pterm.Fatal.Printfln("not found service in %s, please add it", pb.ProtoFileName()) + return fmt.Errorf("not found service in %s", pb.ProtoFileName()) + } + + // NOTE: 寻找主键 + var pkField *PbNormalField + { + for _, field := range msg.normalFields { + // 先简单粗暴用 id 当作主键,后面再改 + if field.Name == "id" { + pkField = field + break + } + } + } + + var rpcBlock string + + for action, actionOpt := range opt.Action { + for _, role := range actionOpt.Roles { + args := map[string]interface{}{ + "PB": pb, + "Model": opt.Model, + "Role": role, + "Action": action, + "GenTo": opt.GenTo, + "ListOptions": opt.ListOptions, + } + + if pkField != nil { + args["PprimaryKey"] = pkField.Name + args["PprimaryKeyType"] = pkField.Type() + } + + var rpcName string + // 先获取一下 rpc 的名字 + { + tpl, err := GetTemplate(TemplateTypeProtoRpcName, action) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + b := log.GetBuffer() + + if role == "default" || role == "def" { + role = "" + } + + err = tpl.Execute(b, map[string]interface{}{ + "PB": pb, + "Model": opt.Model, + "Role": role, + }) + + //rpcName = b.String() + rpcName = strings.ReplaceAll(b.String(), " ", "") + rpcName = strings.ReplaceAll(rpcName, "\n", "") + rpcName = strings.ReplaceAll(rpcName, "\r", "") + rpcName = strings.ReplaceAll(rpcName, "\t", "") + } + + log.Infof("try add rpc %s", rpcName) + + args["RpcName"] = rpcName + args["RequestType"] = rpcName + "Req" + args["ResponseType"] = rpcName + "Rsp" + + // 处理 server.rpc + if rpc := pb.GetRPC(rpcName); rpc == nil { + tpl, err := GetTemplate(TemplateTypeProtoService) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + var b bytes.Buffer + err = tpl.Execute(&b, args) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + rpcBlock += b.String() + } + + // 处理 request + if req := pb.GetMessage(rpcName + "Req"); req == nil { + tpl, err := GetTemplate(TemplateTypeProtoRpcReq, action) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + log.Info(args["ListOptions"]) + + var b bytes.Buffer + err = tpl.Execute(&b, args) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + pb.ProtoBuffer += "\n" + pb.ProtoBuffer += b.String() + } + + // 处理 response + if rsp := pb.GetMessage(rpcName + "Rsp"); rsp == nil { + tpl, err := GetTemplate(TemplateTypeProtoRpcResp, action) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + var b bytes.Buffer + err = tpl.Execute(&b, args) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + pb.ProtoBuffer += "\n" + pb.ProtoBuffer += b.String() + } + } + } + + if rpcBlock != "" { + b := bytes.NewBufferString(pb.ProtoBuffer[:lastRpcPos]) + + b.WriteByte('\n') + + b.WriteString(rpcBlock) + + b.WriteString(pb.ProtoBuffer[lastRpcPos:]) + + pb.ProtoBuffer = b.String() + } + + err = os.WriteFile(pb.protoFilePath, []byte(pb.ProtoBuffer), 0666) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + return nil +} diff --git a/codegen/generate_struct_tag.go b/codegen/generate_struct_tag.go index cb23ec6..b86879e 100644 --- a/codegen/generate_struct_tag.go +++ b/codegen/generate_struct_tag.go @@ -122,7 +122,7 @@ func InjectTagWriteFile(inputPath string, areas []textArea) error { endIdx := bytes.LastIndex(contents[area.Start-1:area.End-1], []byte("`")) + area.Start - 1 log.Infof("append custom tags to %s at %d", contents[area.Start-1:endIdx], endIdx) - pterm.Info.Printfln("append custom tags to %s at %s", pterm.FgBlack.Sprint(pterm.BgCyan.Sprintf("%s", contents[area.Start-1:endIdx])), pterm.FgMagenta.Sprint(endIdx)) + pterm.Info.Printfln("append custom tags to %s at %s", pterm.FgBlack.Sprint(pterm.BgWhite.Sprintf("%s", contents[area.Start-1:endIdx])), pterm.FgMagenta.Sprint(endIdx)) b.Write(contents[lastEnd:endIdx]) b.WriteString(" ") diff --git a/codegen/generate_table.go b/codegen/generate_table.go index fb30816..28c02ca 100644 --- a/codegen/generate_table.go +++ b/codegen/generate_table.go @@ -17,12 +17,6 @@ func GenerateTable(pb *PbPackage) (err error) { // table 文件为覆盖生成 args := map[string]interface{}{ "PB": pb, - "GoImports": []string{ - pb.GoPackage(), - "github.com/lazygophers/log", - "github.com/lazygophers/utils/common", - "github.com/lazygophers/utils/db", - }, } // 读取 Models @@ -41,6 +35,7 @@ func GenerateTable(pb *PbPackage) (err error) { args["Models"] = models } + // 生成 table.go tpl, err := GetTemplate(TemplateTypeStateTable) if err != nil { log.Errorf("err:%v", err) @@ -70,3 +65,117 @@ func GenerateTable(pb *PbPackage) (err error) { return nil } + +func GenerateOrm(pb *PbPackage) (err error) { + err = initStateDirectory(pb) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + // table 文件为覆盖生成 + args := map[string]interface{}{ + "PB": pb, + } + + // 读取 Models + { + var models []string + candy.Each(pb.Messages(), func(message *PbMessage) { + if !message.NeedOrm() { + return + } + + log.Infof("find orm object %s", message.Name) + + models = append(models, message.Name) + }) + + args["Models"] = models + } + + // 生成 table.go + tpl, err := GetTemplate(TemplateTypeOrm) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + file, err := os.OpenFile(GetPath(PathTypeOrm, pb), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, fs.FileMode(0666)) + if err != nil { + log.Errorf("err:%v", err) + return err + } + defer file.Close() + + _, err = file.Write(getFileHeader(pb)) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + err = tpl.Execute(file, args) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + return nil +} + +func GenerateTableName(pb *PbPackage) (err error) { + err = initStateDirectory(pb) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + // table 文件为覆盖生成 + args := map[string]interface{}{ + "PB": pb, + } + + // 读取 Models + { + var models []string + candy.Each(pb.Messages(), func(message *PbMessage) { + if !message.IsTable() { + return + } + + log.Infof("find table %s", message.Name) + + models = append(models, message.Name) + }) + + args["Models"] = models + } + + // 生成 table.go + tpl, err := GetTemplate(TemplateTypeTableName) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + file, err := os.OpenFile(GetPath(PathTypeTableName, pb), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, fs.FileMode(0666)) + if err != nil { + log.Errorf("err:%v", err) + return err + } + defer file.Close() + + _, err = file.Write(getFileHeader(pb)) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + err = tpl.Execute(file, args) + if err != nil { + log.Errorf("err:%v", err) + return err + } + + return nil +} diff --git a/codegen/path.go b/codegen/path.go index dec807f..96c9d62 100644 --- a/codegen/path.go +++ b/codegen/path.go @@ -8,6 +8,8 @@ const ( PathTypePbGo PathType = iota + 1 PathTypeGoMod PathTypeEditorconfig + PathTypeOrm + PathTypeTableName PathTypeState PathTypeStateTable @@ -27,6 +29,12 @@ func GetPath(t PathType, pb *PbPackage) string { case PathTypeEditorconfig: return filepath.Join(pb.ProjectRoot(), ".editorconfig") + case PathTypeOrm: + return filepath.Join(pb.ProjectRoot(), "orm.gen.go") + + case PathTypeTableName: + return filepath.Join(pb.ProjectRoot(), "table_name.gen.go") + case PathTypeState: return filepath.Join(pb.ProjectRoot(), "state") diff --git a/codegen/pbparse.go b/codegen/pbparse.go index 84763a7..d76076e 100644 --- a/codegen/pbparse.go +++ b/codegen/pbparse.go @@ -1,7 +1,7 @@ package codegen import ( - "fmt" + "bytes" "github.com/emicklei/proto" "github.com/lazygophers/codegen/state" "github.com/lazygophers/log" @@ -112,8 +112,60 @@ type PbEnum struct { enum *proto.Enum } +func (p *PbEnum) Enum() *proto.Enum { + return p.enum +} + +func (p *PbEnum) walk() { + +} + +func NewPbEnum(e *proto.Enum) *PbEnum { + p := &PbEnum{ + enum: e, + } + + return p +} + type PbRPC struct { - rpc *proto.RPC + rpc *proto.RPC + Name string +} + +func (p *PbRPC) RPC() *proto.RPC { + return p.rpc +} + +func (p *PbRPC) walk() { +} + +func NewPbRPC(rpc *proto.RPC) *PbRPC { + p := &PbRPC{ + Name: rpc.Name, + rpc: rpc, + } + + return p +} + +type PbService struct { + service *proto.Service +} + +func (p *PbService) Service() *proto.Service { + return p.service +} + +func (p *PbService) walk() { +} + +func NewPbService(service *proto.Service) *PbService { + p := &PbService{ + service: service, + } + + return p } type PbNormalField struct { @@ -188,6 +240,18 @@ func (p *PbMessage) Message() *proto.Message { } func (p *PbMessage) IsTable() bool { + if strings.HasPrefix(p.Name, "Model") { + if !strings.Contains(p.Name, "_") { + return true + } + } + + // TODO: 允许通过对注释的解析判断时候是表 + + return false +} + +func (p *PbMessage) NeedOrm() bool { if strings.HasPrefix(p.Name, "Model") { return true } @@ -238,12 +302,22 @@ type PbPackage struct { proto *proto.Proto - messages map[string]*PbMessage - enums map[string]*PbEnum - rpcs map[string]*PbRPC - options map[string]*PbOption + messages []*PbMessage + enums []*PbEnum + services []*PbService + rpcs []*PbRPC + options []*PbOption + + msgMap map[string]*PbMessage + enumMap map[string]*PbEnum + serviceMap map[string]*PbService + rpcMap map[string]*PbRPC + optionMap map[string]*PbOption + RawGoPackage string PackageName string + + ProtoBuffer string } func (p *PbPackage) ProtoFilePath() string { @@ -251,7 +325,7 @@ func (p *PbPackage) ProtoFilePath() string { } func (p *PbPackage) ProtoFileName() string { - return filepath.Base(p.protoFilePath) + return p.proto.Filename } func (p *PbPackage) Proto() *proto.Proto { @@ -268,11 +342,9 @@ func (p *PbPackage) ProjectRoot() string { func (p *PbPackage) GoPackage() string { if state.Config.GoModulePrefix != "" { - return fmt.Sprintf("%s/%s", - strings.TrimSuffix(state.Config.GoModulePrefix, "/"), - strings.TrimPrefix(p.RawGoPackage, "/")) + return filepath.ToSlash(filepath.Join(state.Config.GoModulePrefix, p.RawGoPackage)) } else { - return p.RawGoPackage + return filepath.ToSlash(filepath.Join(p.RawGoPackage)) } } @@ -282,72 +354,115 @@ func (p *PbPackage) GoPackageName() string { func (p *PbPackage) Messages() []*PbMessage { var messages []*PbMessage - for _, m := range p.messages { + for _, m := range p.msgMap { messages = append(messages, m) } return messages } -func (p *PbPackage) GetParent(v proto.Visitee) string { - var nameList []string +func (p *PbPackage) GetMessage(msg string) *PbMessage { + return p.msgMap[msg] +} + +func (p *PbPackage) Enums() []*PbEnum { + return p.enums +} + +func (p *PbPackage) GetEnum(e string) *PbEnum { + return p.enumMap[e] +} + +func (p *PbPackage) RPCs() []*PbRPC { + return p.rpcs +} + +func (p *PbPackage) GetRPC(rpc string) *PbRPC { + return p.rpcMap[rpc] +} + +func (p *PbPackage) Services() []*PbService { + return p.services +} + +func (p *PbPackage) GetService(s string) *PbService { + return p.serviceMap[s] +} + +func (p *PbPackage) GetMessageFullName(e *proto.Message) string { + var names []string - for v != nil { - m := p.messages[fmt.Sprintf("%p", v)] + var walk func(m *proto.Message) + walk = func(m *proto.Message) { if m == nil { - break + return } - nameList = append(nameList, m.Name) + names = append(names, m.Name) - vv := &ProtoVisitor{} - v.Accept(vv) - if len(vv.msgList) == len(p.messages) || len(vv.msgList) != 1 { - // 到顶层了 - break + if m.Parent == nil { + return } - x := vv.msgList[0] - v = x.Parent - } - candy.Reverse(nameList) - return strings.Join(nameList, ".") -} + switch x := m.Parent.(type) { + case *proto.Message: + walk(x) -func (p *PbPackage) GetMessageFullName(e *proto.Message) string { - parent := p.GetParent(e.Parent) - var fullName string - if parent != "" { - fullName = fmt.Sprintf("%s.%s", parent, e.Name) - } else { - fullName = e.Name + case *proto.Proto: + // do nothing + + default: + log.Warnf("unknown parent type:%T", x) + } } - return fullName + + walk(e) + + return strings.Join(candy.Reverse(names), "_") } func (p *PbPackage) Walk() { proto.Walk(p.proto, proto.WithMessage(func(m *proto.Message) { + m.Name = p.GetMessageFullName(m) + log.Infof("message:%v", m.Name) pterm.Info.Printfln("find message:%s", m.Name) - p.messages[p.GetMessageFullName(m)] = NewPbMessage(m) - p.messages[m.Name].walk() + p.msgMap[m.Name] = NewPbMessage(m) + p.msgMap[m.Name].walk() + p.messages = append(p.messages, p.msgMap[m.Name]) + }), + proto.WithService(func(s *proto.Service) { + log.Infof("service:%v", s.Name) + pterm.Info.Printfln("find service:%s", s.Name) + + p.serviceMap[s.Name] = NewPbService(s) + p.serviceMap[s.Name].walk() + p.services = append(p.services, p.serviceMap[s.Name]) + }), + proto.WithRPC(func(r *proto.RPC) { + log.Infof("rpc:%v", r.Name) + pterm.Info.Printfln("find rpc:%s", r.Name) + + p.rpcMap[r.Name] = NewPbRPC(r) + p.rpcMap[r.Name].walk() + p.rpcs = append(p.rpcs, p.rpcMap[r.Name]) }), - //proto.WithService(func(s *proto.Service) { - // log.Infof("service:%v", s.Name) - //}), - //proto.WithRPC(func(r *proto.RPC) { - // log.Infof("rpc:%v", r.Name) - //}), proto.WithEnum(func(e *proto.Enum) { log.Infof("enum:%v", e.Name) + pterm.Info.Printfln("find enum:%s", e.Name) + + p.enumMap[e.Name] = NewPbEnum(e) + p.enumMap[e.Name].walk() + p.enums = append(p.enums, p.enumMap[e.Name]) }), proto.WithOption(func(option *proto.Option) { log.Infof("option:%v", option.Name) pterm.Info.Printfln("find option:%s", option.Name) - p.options[option.Name] = NewPbOption(option) - p.options[option.Name].walk() + p.optionMap[option.Name] = NewPbOption(option) + p.optionMap[option.Name].walk() + p.options = append(p.options, p.optionMap[option.Name]) }), proto.WithPackage(func(pp *proto.Package) { log.Infof("package:%v", pp.Name) @@ -357,7 +472,7 @@ func (p *PbPackage) Walk() { }), ) - if o, ok := p.options["go_package"]; ok { + if o, ok := p.optionMap["go_package"]; ok { p.RawGoPackage = o.Value } } @@ -366,22 +481,30 @@ func NewPbPackage(protoFilePath string, p *proto.Proto) *PbPackage { return &PbPackage{ protoFilePath: protoFilePath, proto: p, - messages: map[string]*PbMessage{}, - enums: map[string]*PbEnum{}, - rpcs: map[string]*PbRPC{}, - options: map[string]*PbOption{}, + messages: nil, + enums: nil, + services: nil, + rpcs: nil, + options: nil, + msgMap: map[string]*PbMessage{}, + enumMap: map[string]*PbEnum{}, + serviceMap: map[string]*PbService{}, + rpcMap: map[string]*PbRPC{}, + optionMap: map[string]*PbOption{}, + RawGoPackage: "", + PackageName: "", + ProtoBuffer: "", } } -func ParseProto(path string) (*PbPackage, error) { - pbFile, err := os.Open(path) +func ParseProto(path string, cacheFiles ...bool) (*PbPackage, error) { + protoBuffer, err := os.ReadFile(path) if err != nil { log.Errorf("err:%v", err) return nil, err } - defer pbFile.Close() - pb, err := proto.NewParser(pbFile).Parse() + pb, err := proto.NewParser(bytes.NewBuffer(protoBuffer)).Parse() if err != nil { log.Errorf("err:%v", err) return nil, err @@ -389,6 +512,12 @@ func ParseProto(path string) (*PbPackage, error) { p := NewPbPackage(path, pb) + if len(cacheFiles) > 0 { + if cacheFiles[0] { + p.ProtoBuffer = string(protoBuffer) + } + } + p.Walk() return p, nil diff --git a/codegen/resource.go b/codegen/resource.go index 173f16f..5f41d9f 100644 --- a/codegen/resource.go +++ b/codegen/resource.go @@ -2,12 +2,14 @@ package codegen import ( "embed" + "fmt" "github.com/lazygophers/codegen/state" "github.com/lazygophers/log" "github.com/lazygophers/utils/anyx" "github.com/lazygophers/utils/candy" "github.com/lazygophers/utils/stringx" "github.com/pterm/pterm" + "go.uber.org/atomic" "os" "strings" "text/template" @@ -21,6 +23,13 @@ type TemplateType uint8 const ( TemplateTypeEditorconfig TemplateType = iota + 1 + TemplateTypeOrm + TemplateTypeTableName + + TemplateTypeProtoService + TemplateTypeProtoRpcName + TemplateTypeProtoRpcReq + TemplateTypeProtoRpcResp TemplateTypeStateTable TemplateTypeStateConf @@ -28,7 +37,7 @@ const ( TemplateTypeStateState ) -func GetTemplate(t TemplateType) (tpl *template.Template, err error) { +func GetTemplate(t TemplateType, args ...string) (tpl *template.Template, err error) { var systemPath, embedPath string switch t { @@ -36,6 +45,45 @@ func GetTemplate(t TemplateType) (tpl *template.Template, err error) { systemPath = state.Config.Template.Editorconfig embedPath = "template/.editorconfig" + case TemplateTypeOrm: + systemPath = state.Config.Template.Orm + embedPath = "template/orm.gtpl" + + case TemplateTypeTableName: + systemPath = state.Config.Template.TableName + embedPath = "template/table_name.gtpl" + + case TemplateTypeProtoService: + systemPath = state.Config.Template.Proto.Service + embedPath = "template/proto/rpc_service.gtpl" + + case TemplateTypeProtoRpcName: + if len(args) != 1 { + panic("Must provide") + } + if v, ok := state.Config.Template.Proto.Rpc[args[0]]; ok && v != nil { + systemPath = v.Name + } + embedPath = fmt.Sprintf("template/proto/%s.name.rpc.gtpl", args[0]) + + case TemplateTypeProtoRpcReq: + if len(args) != 1 { + panic("Must provide") + } + if v, ok := state.Config.Template.Proto.Rpc[args[0]]; ok && v != nil { + systemPath = v.Req + } + embedPath = fmt.Sprintf("template/proto/%s.req.rpc.gtpl", args[0]) + + case TemplateTypeProtoRpcResp: + if len(args) != 1 { + panic("Must provide") + } + if v, ok := state.Config.Template.Proto.Rpc[args[0]]; ok && v != nil { + systemPath = v.Resp + } + embedPath = fmt.Sprintf("template/proto/%s.resp.rpc.gtpl", args[0]) + case TemplateTypeStateTable: systemPath = state.Config.Template.Table embedPath = "template/state/table.gtpl" @@ -103,6 +151,32 @@ func GetTemplate(t TemplateType) (tpl *template.Template, err error) { return tpl, nil } +var ( + counters = map[string]*atomic.Int64{} +) + +func IncrWithKey(key string, def int64) int64 { + if v, ok := counters[key]; ok { + return v.Inc() + } + + v := atomic.NewInt64(0) + counters[key] = v + + return def +} + +func DecrWithKey(key string, def int64) int64 { + if v, ok := counters[key]; ok { + return v.Dec() + } + + v := atomic.NewInt64(0) + counters[key] = v + + return def +} + var DefaultTemplateFunc = template.FuncMap{ "ToCamel": stringx.ToCamel, "ToSnake": stringx.ToSnake, @@ -133,4 +207,7 @@ var DefaultTemplateFunc = template.FuncMap{ "First": candy.First[string], "Last": candy.Last[string], "Contains": candy.Contains[string], + + "IncrKey": IncrWithKey, + "DecrKey": DecrWithKey, } diff --git a/codegen/template/orm.gtpl b/codegen/template/orm.gtpl new file mode 100644 index 0000000..9bb2ebd --- /dev/null +++ b/codegen/template/orm.gtpl @@ -0,0 +1,15 @@ +package {{ .PB.GoPackageName }} + +import ( + "database/sql/driver" + "github.com/lazygophers/utils" +) + +{{ range $key, $value := .Models }}func (m *{{ $value }}) Scan(value interface{}) error { + return utils.Scan(m, value) +} + +func (m *{{ $value }}) Value() (driver.Value, error) { + return utils.Value(m) +} +{{ end }} diff --git a/codegen/template/proto/add.name.rpc.gtpl b/codegen/template/proto/add.name.rpc.gtpl new file mode 100644 index 0000000..28efe4b --- /dev/null +++ b/codegen/template/proto/add.name.rpc.gtpl @@ -0,0 +1 @@ +Add{{ TrimPrefix .Model "Model" }}{{ ToCamel .Role }} diff --git a/codegen/template/proto/add.req.rpc.gtpl b/codegen/template/proto/add.req.rpc.gtpl new file mode 100644 index 0000000..5d401ff --- /dev/null +++ b/codegen/template/proto/add.req.rpc.gtpl @@ -0,0 +1,4 @@ +message {{ .RequestType }} { + // @validate: required + {{ .Model }} {{ ToSnake (TrimPrefix .Model "Model") }} = 1; +} diff --git a/codegen/template/proto/add.resp.rpc.gtpl b/codegen/template/proto/add.resp.rpc.gtpl new file mode 100644 index 0000000..78cf09d --- /dev/null +++ b/codegen/template/proto/add.resp.rpc.gtpl @@ -0,0 +1,3 @@ +message {{ .ResponseType }} { + {{ .Model }} {{ ToSnake (TrimPrefix .Model "Model") }} = 1; +} diff --git a/codegen/template/proto/del.name.rpc.gtpl b/codegen/template/proto/del.name.rpc.gtpl new file mode 100644 index 0000000..58d3d51 --- /dev/null +++ b/codegen/template/proto/del.name.rpc.gtpl @@ -0,0 +1 @@ +Del{{ TrimPrefix .Model "Model" }}{{ ToCamel .Role }} diff --git a/codegen/template/proto/del.req.rpc.gtpl b/codegen/template/proto/del.req.rpc.gtpl new file mode 100644 index 0000000..c47b9d8 --- /dev/null +++ b/codegen/template/proto/del.req.rpc.gtpl @@ -0,0 +1,4 @@ +message {{ .RequestType }} { +{{with .PprimaryKey }} // @validate: required + {{ $.PprimaryKeyType}} {{ $.PprimaryKey }} = 1;{{end}} +} diff --git a/codegen/template/proto/del.resp.rpc.gtpl b/codegen/template/proto/del.resp.rpc.gtpl new file mode 100644 index 0000000..c22d378 --- /dev/null +++ b/codegen/template/proto/del.resp.rpc.gtpl @@ -0,0 +1,2 @@ +message {{ .ResponseType }} { +} diff --git a/codegen/template/proto/get.name.rpc.gtpl b/codegen/template/proto/get.name.rpc.gtpl new file mode 100644 index 0000000..330e787 --- /dev/null +++ b/codegen/template/proto/get.name.rpc.gtpl @@ -0,0 +1 @@ +Get{{ TrimPrefix .Model "Model" }}{{ ToCamel .Role }} diff --git a/codegen/template/proto/get.req.rpc.gtpl b/codegen/template/proto/get.req.rpc.gtpl new file mode 100644 index 0000000..c47b9d8 --- /dev/null +++ b/codegen/template/proto/get.req.rpc.gtpl @@ -0,0 +1,4 @@ +message {{ .RequestType }} { +{{with .PprimaryKey }} // @validate: required + {{ $.PprimaryKeyType}} {{ $.PprimaryKey }} = 1;{{end}} +} diff --git a/codegen/template/proto/get.resp.rpc.gtpl b/codegen/template/proto/get.resp.rpc.gtpl new file mode 100644 index 0000000..78cf09d --- /dev/null +++ b/codegen/template/proto/get.resp.rpc.gtpl @@ -0,0 +1,3 @@ +message {{ .ResponseType }} { + {{ .Model }} {{ ToSnake (TrimPrefix .Model "Model") }} = 1; +} diff --git a/codegen/template/proto/list.name.rpc.gtpl b/codegen/template/proto/list.name.rpc.gtpl new file mode 100644 index 0000000..0df70d6 --- /dev/null +++ b/codegen/template/proto/list.name.rpc.gtpl @@ -0,0 +1 @@ +List{{ TrimPrefix .Model "Model" }}{{ ToCamel .Role }} diff --git a/codegen/template/proto/list.req.rpc.gtpl b/codegen/template/proto/list.req.rpc.gtpl new file mode 100644 index 0000000..0447b93 --- /dev/null +++ b/codegen/template/proto/list.req.rpc.gtpl @@ -0,0 +1,10 @@ +message {{ .RequestType }} { + enum ListOption { + ListOptionNil = {{ IncrKey $.RequestType 0 }};{{ range $key, $value := .ListOptions }}{{ with $value }} + // @type: {{ $value }}{{ end }} + ListOption{{ ToCamel $key }} = {{ IncrKey $.RequestType 0 }};{{ end }} + } + + // @validate: required + core.ListOption list_option = 1; +} diff --git a/codegen/template/proto/list.resp.rpc.gtpl b/codegen/template/proto/list.resp.rpc.gtpl new file mode 100644 index 0000000..4cc67e0 --- /dev/null +++ b/codegen/template/proto/list.resp.rpc.gtpl @@ -0,0 +1,4 @@ +message {{ .ResponseType }} { + core.Paginate paginate = 1; + {{ .Model }} {{ ToSnake (TrimPrefix .Model "Model") }} = 2; +} diff --git a/codegen/template/proto/rpc_service.gtpl b/codegen/template/proto/rpc_service.gtpl new file mode 100644 index 0000000..9f3d111 --- /dev/null +++ b/codegen/template/proto/rpc_service.gtpl @@ -0,0 +1,8 @@ +{{ with .Role }} + // @role: {{ $.Role }} {{end}}{{ with .Url }} + // @url: {{ $.Url }} {{end}}{{ with .GenTo }} + // @gen: {{ $.GenTo }} {{end}}{{ with .Model }} + // @model: {{ $.Model }} {{end}}{{ with .Action }} + // @action: {{ $.Action }} {{end}} + rpc {{ .RpcName }} ({{ .RequestType }}) returns ({{ .ResponseType }}) { + }; diff --git a/codegen/template/proto/set.name.rpc.gtpl b/codegen/template/proto/set.name.rpc.gtpl new file mode 100644 index 0000000..f966888 --- /dev/null +++ b/codegen/template/proto/set.name.rpc.gtpl @@ -0,0 +1 @@ +Set{{ TrimPrefix .Model "Model" }}{{ ToCamel .Role }} diff --git a/codegen/template/proto/set.req.rpc.gtpl b/codegen/template/proto/set.req.rpc.gtpl new file mode 100644 index 0000000..5d401ff --- /dev/null +++ b/codegen/template/proto/set.req.rpc.gtpl @@ -0,0 +1,4 @@ +message {{ .RequestType }} { + // @validate: required + {{ .Model }} {{ ToSnake (TrimPrefix .Model "Model") }} = 1; +} diff --git a/codegen/template/proto/set.resp.rpc.gtpl b/codegen/template/proto/set.resp.rpc.gtpl new file mode 100644 index 0000000..78cf09d --- /dev/null +++ b/codegen/template/proto/set.resp.rpc.gtpl @@ -0,0 +1,3 @@ +message {{ .ResponseType }} { + {{ .Model }} {{ ToSnake (TrimPrefix .Model "Model") }} = 1; +} diff --git a/codegen/template/proto/update.name.rpc.gtpl b/codegen/template/proto/update.name.rpc.gtpl new file mode 100644 index 0000000..b786485 --- /dev/null +++ b/codegen/template/proto/update.name.rpc.gtpl @@ -0,0 +1 @@ +Update{{ TrimPrefix .Model "Model" }}{{ ToCamel .Role }} diff --git a/codegen/template/proto/update.req.rpc.gtpl b/codegen/template/proto/update.req.rpc.gtpl new file mode 100644 index 0000000..c47b9d8 --- /dev/null +++ b/codegen/template/proto/update.req.rpc.gtpl @@ -0,0 +1,4 @@ +message {{ .RequestType }} { +{{with .PprimaryKey }} // @validate: required + {{ $.PprimaryKeyType}} {{ $.PprimaryKey }} = 1;{{end}} +} diff --git a/codegen/template/proto/update.resp.rpc.gtpl b/codegen/template/proto/update.resp.rpc.gtpl new file mode 100644 index 0000000..c22d378 --- /dev/null +++ b/codegen/template/proto/update.resp.rpc.gtpl @@ -0,0 +1,2 @@ +message {{ .ResponseType }} { +} diff --git a/codegen/template/state/table.gtpl b/codegen/template/state/table.gtpl index 52a384f..0c0beb3 100644 --- a/codegen/template/state/table.gtpl +++ b/codegen/template/state/table.gtpl @@ -1,13 +1,16 @@ package state -{{ if gt (len .GoImports) 0 }}import ( -{{ range $key, $value := .GoImports }} "{{ $value }}" -{{ end }}){{ end }} +import ( + {{ .PB.GoPackage }} + "github.com/lazygophers/log" + "github.com/lazygophers/utils/common" + "github.com/lazygophers/utils/db" +) var ( _db *db.Client -{{ range $key, $value := .Models}} {{TrimPrefix $value "Model"}} *db.Model[{{ $.PB.GoPackageName }}.{{ $value }}] +{{ range $key, $value := .Models }} {{TrimPrefix $value "Model"}} *db.Model[{{ $.PB.GoPackageName }}.{{ $value }}] {{ end }}) func ConnectDatebase() (err error) { diff --git a/codegen/template/table_name.gtpl b/codegen/template/table_name.gtpl new file mode 100644 index 0000000..99c814e --- /dev/null +++ b/codegen/template/table_name.gtpl @@ -0,0 +1,6 @@ +package {{ .PB.GoPackageName }} + +{{ range $key, $value := .Models }}func ({{ $value }}) TableName() string { + return "{{ $.PB.GoPackageName }}_{{ ToSnake (TrimPrefix $value "Model") }}" +} +{{ end }} diff --git a/example.codegen.cfg.yaml b/example.codegen.cfg.yaml index bf31514..8ff2f7a 100644 --- a/example.codegen.cfg.yaml +++ b/example.codegen.cfg.yaml @@ -8,8 +8,42 @@ go_module_prefix: "" + # orm.go 模板文件路径,用于 gorm 的相关数据生成 + orm: "" + # table_name.go 模板文件路径,用于 gorm 的相关表名 + table_name: "" + + # .proto 文件相关的模板文件 + proto: + # rpc 服务的模板文件路径 + service: "proto.service 的模板文件路径" + + # rpc 相关的 + rpc: + # key → action value → 模版文件路径 + add: + name: "rpc 名字的生成模板" + req: "message.request 生成模板" + resp: "messsage.response 生成模板" + set: + name: "rpc 名字的生成模板" + req: "message.request 生成模板" + resp: "messsage.response 生成模板" + update: + name: "rpc 名字的生成模板" + req: "message.request 生成模板" + resp: "messsage.response 生成模板" + del: + name: "rpc 名字的生成模板" + req: "message.request 生成模板" + resp: "messsage.response 生成模板" + # 如果 key 为 list,会触发筛选项的生辰 + list: + name: "rpc 名字的生成模板" + req: "message.request 生成模板" + resp: "messsage.response 生成模板" # state/table.go 模板文件路径 table: "" diff --git a/go.mod b/go.mod index 6440917..4d7a3e2 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( atomicgo.dev/cursor v0.2.0 // indirect atomicgo.dev/keyboard v0.2.9 // indirect atomicgo.dev/schedule v0.1.0 // indirect - github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic v1.11.7 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect diff --git a/go.sum b/go.sum index 702bda8..58c388e 100644 --- a/go.sum +++ b/go.sum @@ -16,8 +16,7 @@ github.com/MarvinJWendt/testza v0.4.2/go.mod h1:mSdhXiKH8sg/gQehJ63bINcCKp7RtYew github.com/MarvinJWendt/testza v0.5.2 h1:53KDo64C1z/h/d/stCYCPY69bt/OSwjq5KpFNwi+zB4= github.com/MarvinJWendt/testza v0.5.2/go.mod h1:xu53QFE5sCdjtMCKk8YMQ2MnymimEctc4n3EjyIYvEY= github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk= -github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= -github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic v1.11.7 h1:k/l9p1hZpNIMJSk37wL9ltkcpqLfIho1vYthi4xT2t4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= diff --git a/state/config.go b/state/config.go index 08eb14a..432b377 100644 --- a/state/config.go +++ b/state/config.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/lazygophers/log" "github.com/lazygophers/utils/app" + "github.com/lazygophers/utils/defaults" "github.com/lazygophers/utils/json" "github.com/lazygophers/utils/osx" "github.com/lazygophers/utils/runtime" @@ -16,16 +17,35 @@ import ( "strings" ) +type CfgStyle struct { + Go string `json:"go,omitempty" yaml:"go,omitempty" toml:"go,omitempty" default:"fiber"` + + ListPagination string `json:"list_pagination,omitempty" yaml:"list_pagination,omitempty" toml:"list_pagination,omitempty" default:"offset"` +} + +type CfgProtoRpc struct { + Name string `json:"name,omitempty" yaml:"name,omitempty" toml:"name,omitempty"` + Req string `json:"req,omitempty" yaml:"req,omitempty" toml:"req,omitempty"` + Resp string `json:"resp,omitempty" yaml:"resp,omitempty" toml:"resp,omitempty"` +} + +type CfgProto struct { + Rpc map[string]*CfgProtoRpc `json:"rpc,omitempty" yaml:"rpc,omitempty" toml:"rpc,omitempty"` + + Service string `json:"service,omitempty" yaml:"service,omitempty" toml:"service,omitempty"` +} + type CfgTemplate struct { Editorconfig string `json:"editorconfig,omitempty" yaml:"editorconfig,omitempty" toml:"editorconfig,omitempty"` - Table string `json:"table,omitempty" yaml:"table,omitempty" toml:"table,omitempty"` - Conf string `json:"conf,omitempty" yaml:"conf,omitempty" toml:"conf,omitempty"` - Cache string `json:"cache,omitempty" yaml:"cache,omitempty" toml:"cache,omitempty"` - State string `json:"state,omitempty" yaml:"state,omitempty" toml:"state,omitempty"` -} + Orm string `json:"orm,omitempty" yaml:"orm,omitempty" toml:"orm,omitempty"` + TableName string `json:"table_name,omitempty" yaml:"table_name,omitempty" toml:"table_name,omitempty"` -func (p *CfgTemplate) apply() { + Proto *CfgProto `json:"proto,omitempty" yaml:"proto,omitempty" toml:"proto,omitempty"` + Table string `json:"table,omitempty" yaml:"table,omitempty" toml:"table,omitempty"` + Conf string `json:"conf,omitempty" yaml:"conf,omitempty" toml:"conf,omitempty"` + Cache string `json:"cache,omitempty" yaml:"cache,omitempty" toml:"cache,omitempty"` + State string `json:"state,omitempty" yaml:"state,omitempty" toml:"state,omitempty"` } type CfgTables struct { @@ -63,6 +83,8 @@ type Cfg struct { OutputPath string `json:"output_path,omitempty" yaml:"output_path,omitempty" toml:"output_path,omitempty"` + Style *CfgStyle `json:"style,omitempty" yaml:"style,omitempty" toml:"style,omitempty"` + Template *CfgTemplate `json:"template,omitempty" yaml:"template,omitempty" toml:"template,omitempty"` // 对于原始数据,key 为 tag 名。value.key 为字段名,value.value 为 tag 内容 @@ -100,12 +122,25 @@ func (p *Cfg) apply() (err error) { if p.Template == nil { p.Template = new(CfgTemplate) } - p.Template.apply() + if p.Template.Proto == nil { + p.Template.Proto = new(CfgProto) + } + if len(p.Template.Proto.Rpc) == 0 { + p.Template.Proto.Rpc = make(map[string]*CfgProtoRpc) + } if p.Tables == nil { p.Tables = new(CfgTables) } - p.Tables.apply() + + if p.Style == nil { + p.Style = new(CfgStyle) + } + err = defaults.SetDefaults(p.Style) + if err != nil { + log.Errorf("err:%v", err) + return err + } // NOTE: struct 标签 {