Skip to content

Commit

Permalink
Merge pull request #34 from lazygophers/luoxin
Browse files Browse the repository at this point in the history
针对表类型的orm过滤,使得更精准的命中需要生成orm相关代码的类型
  • Loading branch information
Luoxin authored Aug 29, 2024
2 parents e458aa8 + 6e58d17 commit d20fdad
Show file tree
Hide file tree
Showing 155 changed files with 432 additions and 62 deletions.
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ build: ## build
goreleaser build --clean --snapshot --single-target --config=debug.goreleaser.yaml
#GOVERSION=$(shell go version | awk '{print $$3;}') goreleaser --clean --snapshot --skip=publish,validate --timeout=24h

.PHONY: gen
gen: ## gen
codegen i18n tran -s ./state/localize/zh.yaml --generate-const=true --all-languages=false -l zh,zh-hans,en

.PHONY: tran
tran: ## tran
codegen i18n tran --generate-const=true -s ./state/localize/zh.yaml

.PHONY: lint
lint: ## lint go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
golangci-lint run -v
Expand Down
4 changes: 4 additions & 0 deletions cli/gen/gen_all.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ var GenAllHooks = []GenHook{
var allCmd = &cobra.Command{
Use: "all",
Aliases: []string{"a", "all-actions"},
PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) {
mergeStateFlags(cmd)
return nil
},
RunE: func(cmd *cobra.Command, args []string) (err error) {
for _, hook := range GenAllHooks {
err = hook(cmd, args)
Expand Down
19 changes: 12 additions & 7 deletions cli/gen/gen_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ var GenStateHook = []GenHook{
}

var stateCmd = &cobra.Command{
Use: "state",
Use: "state",
PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) {
mergeStateFlags(cmd)
return nil
},
RunE: runGenState,
}

func runGenState(cmd *cobra.Command, args []string) (err error) {
mergeStateFlags(cmd)

err = codegen.GenerateState(pb)
if err != nil {
log.Errorf("err:%v", err)
Expand Down Expand Up @@ -64,13 +66,16 @@ func mergeStateFlags(cmd *cobra.Command) {
} else if state.LazyConfig.Gen.Config != nil {
state.Config.State.Config = *state.LazyConfig.Gen.Config
}

mergeStateTableFlags(cmd)
}

func initStateFlags(cmd *cobra.Command) {
cmd.Flags().Bool("table", state.Config.State.Table, state.Localize(state.I18nTagCliGenStateFlagsTable))
cmd.Flags().Bool("cache", state.Config.State.Cache, state.Localize(state.I18nTagCliGenStateFlagsCache))
cmd.Flags().Bool("i18n", state.Config.State.I18n, state.Localize(state.I18nTagCliGenStateFlagsI18n))
cmd.Flags().Bool("config", state.Config.State.Config, state.Localize(state.I18nTagCliGenStateFlagsConfig))
cmd.PersistentFlags().Bool("table", state.Config.State.Table, state.Localize(state.I18nTagCliGenStateFlagsTable))
cmd.PersistentFlags().Bool("cache", state.Config.State.Cache, state.Localize(state.I18nTagCliGenStateFlagsCache))
cmd.PersistentFlags().Bool("i18n", state.Config.State.I18n, state.Localize(state.I18nTagCliGenStateFlagsI18n))
cmd.PersistentFlags().Bool("config", state.Config.State.Config, state.Localize(state.I18nTagCliGenStateFlagsConfig))
initStateTableFlags(cmd)
}

func initState() {
Expand Down
24 changes: 23 additions & 1 deletion cli/gen/gen_table.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package gen

import (
"github.com/lazygophers/codegen/cli/utils"
"github.com/lazygophers/codegen/codegen"
"github.com/lazygophers/codegen/state"
"github.com/lazygophers/log"
"github.com/spf13/cobra"
)

var tableCmd = &cobra.Command{
Use: "table",
Use: "table",
PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) {
mergeStateTableFlags(cmd)
return nil
},
RunE: runGenTable,
}

Expand Down Expand Up @@ -36,6 +41,10 @@ func runGenTable(cmd *cobra.Command, args []string) (err error) {

var stateTableCmd = &cobra.Command{
Use: "table",
PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) {
mergeStateTableFlags(cmd)
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
state.Config.State.Table = true
return runGenStateTable(cmd, args)
Expand All @@ -52,13 +61,26 @@ func runGenStateTable(cmd *cobra.Command, args []string) (err error) {
return nil
}

func mergeStateTableFlags(cmd *cobra.Command) {
if cmd.Flag("template-state-table").Changed {
state.Config.Template.State.Table = utils.GetString("template-state-table", cmd)
}
}

func initStateTableFlags(cmd *cobra.Command) {
cmd.PersistentFlags().String("template-state-table", state.Config.Template.State.Table, state.Localize(state.I18nTagCliGenTableFlagsTemplateStateTable))
}

func initTable() {
tableCmd.Short = state.Localize(state.I18nTagCliGenTableShort)
tableCmd.Long = state.Localize(state.I18nTagCliGenTableLong)

stateTableCmd.Short = state.Localize(state.I18nTagCliGenStateTableShort)
stateTableCmd.Long = state.Localize(state.I18nTagCliGenStateTableLong)

initStateTableFlags(stateTableCmd)
initStateTableFlags(tableCmd)

stateCmd.AddCommand(stateTableCmd)
genCmd.AddCommand(tableCmd)
}
2 changes: 1 addition & 1 deletion codegen/generate_add_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (p *AddRpcOption) ParseListOption(s string, msg *PbMessage) {
option = item

// 如果没填,会按照字段类型填充默认数据
if field, ok := msg.normalFields[option]; ok {
if field, ok := msg.normalFieldMap[option]; ok {
optionType = field.Type()
}

Expand Down
41 changes: 36 additions & 5 deletions codegen/generate_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/pterm/pterm"
"io/fs"
"os"
"strings"
)

func GenerateStateTable(pb *PbPackage) (err error) {
Expand Down Expand Up @@ -80,16 +81,46 @@ func GenerateOrm(pb *PbPackage) (err error) {
var models []string
candy.Each(pb.Messages(), func(message *PbMessage) {
// 先全部允许。实际使用的时候要考虑被 model 引用的场景
//if !message.NeedOrm() {
// return
//}
if !message.NeedOrm() {
log.Warnf("skip message %s, because it's not a model", message.FullName)
return
}

log.Infof("find orm object %s", message.FullName)

models = append(models, message.FullName)

// 是 Model 类的,所以它的第一层都需要orm相关的配置
candy.Each(message.NormalFields(), func(field *PbNormalField) {
switch field.Type() {
case "bool":
return
case "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64":
return
case "float", "double":
return
case "string":
return
}

if field.Type() != field.FullType() {
return
}

names := GetFullNames(field.Field())

for i := len(names); i >= 0; i-- {
if pb.GetMessage(strings.Join(names[i:], "_")) == nil {
continue
}

models = append(models, strings.Join(names[i:], "_"))

break
}
})
})

args["Models"] = models
args["Models"] = candy.Sort(candy.Unique(models))
}

// 生成 table.go
Expand Down
79 changes: 45 additions & 34 deletions codegen/pbparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,15 @@ func (p *PbNormalField) Field() *proto.NormalField {
}

func (p *PbNormalField) Type() string {
idx := strings.LastIndex(p.field.Type, ".")
if idx != -1 {
return p.field.Type[idx+1:]
} else {
return p.field.Type
}
}

func (p *PbNormalField) FullType() string {
return p.field.Type
}

Expand Down Expand Up @@ -443,16 +452,22 @@ func NewPbEnumField(f *proto.EnumField) *PbEnumField {
}

type PbMessage struct {
message *proto.Message
normalFields map[string]*PbNormalField
mapFields map[string]*PbMapField
enumFields map[string]*PbEnumField
FullName string
Name string
message *proto.Message

normalFieldMap map[string]*PbNormalField
mapFieldMap map[string]*PbMapField
enumFieldMap map[string]*PbEnumField

normalFieldList []*PbNormalField
mapFieldList []*PbMapField
enumFieldList []*PbEnumField

FullName string
Name string
}

func (p *PbMessage) PrimaryField() (pkField *PbNormalField) {
for _, field := range p.normalFields {
for _, field := range p.normalFieldMap {
// 先简单粗暴用 id 当作主键,后面再改
if field.Name == "id" {
pkField = field
Expand All @@ -464,30 +479,15 @@ func (p *PbMessage) PrimaryField() (pkField *PbNormalField) {
}

func (p *PbMessage) NormalFields() []*PbNormalField {
fields := make([]*PbNormalField, 0, len(p.normalFields))
for _, field := range p.normalFields {
fields = append(fields, field)
}

return fields
return p.normalFieldList
}

func (p *PbMessage) MapFields() []*PbMapField {
fields := make([]*PbMapField, 0, len(p.mapFields))
for _, field := range p.mapFields {
fields = append(fields, field)
}

return fields
return p.mapFieldList
}

func (p *PbMessage) EnumFields() []*PbEnumField {
fields := make([]*PbEnumField, 0, len(p.enumFields))
for _, field := range p.enumFields {
fields = append(fields, field)
}

return fields
return p.enumFieldList
}

func (p *PbMessage) Message() *proto.Message {
Expand Down Expand Up @@ -527,17 +527,20 @@ func (p *PbMessage) walk() {

for _, field := range visitor.mapFields {
pterm.Info.Printfln("find map field:%s in %s", field.Name, p.FullName)
p.mapFields[field.Name] = NewPbMapField(field)
p.mapFieldMap[field.Name] = NewPbMapField(field)
p.mapFieldList = append(p.mapFieldList, p.mapFieldMap[field.Name])
}

for _, field := range visitor.normalFields {
pterm.Info.Printfln("find normal field:%s in %s", field.Name, p.FullName)
p.normalFields[field.Name] = NewPbNormalField(field)
p.normalFieldMap[field.Name] = NewPbNormalField(field)
p.normalFieldList = append(p.normalFieldList, p.normalFieldMap[field.Name])
}

for _, field := range visitor.enumFields {
pterm.Info.Printfln("find enum field:%s in %s", field.Name, p.FullName)
p.enumFields[field.Name] = NewPbEnumField(field)
p.enumFieldMap[field.Name] = NewPbEnumField(field)
p.enumFieldList = append(p.enumFieldList, p.enumFieldMap[field.Name])
}
}
}
Expand All @@ -546,9 +549,9 @@ func NewPbMessage(m *proto.Message) *PbMessage {
p := &PbMessage{
message: m,

normalFields: map[string]*PbNormalField{},
mapFields: map[string]*PbMapField{},
enumFields: map[string]*PbEnumField{},
normalFieldMap: map[string]*PbNormalField{},
mapFieldMap: map[string]*PbMapField{},
enumFieldMap: map[string]*PbEnumField{},
}
p.walk()
return p
Expand Down Expand Up @@ -658,7 +661,7 @@ func (p *PbPackage) GetService(s string) *PbService {
return p.serviceMap[s]
}

func GetFullName(e proto.Visitee) string {
func GetFullNames(e proto.Visitee) []string {
var names []string

var walk func(m proto.Visitee)
Expand All @@ -684,7 +687,11 @@ func GetFullName(e proto.Visitee) string {
}

case *proto.Proto:
// do nothing
// do nothing

case *proto.NormalField:
names = append(names, x.Type)
walk(x.Parent)

default:
log.Panicf("unknown parent type:%T", x)
Expand All @@ -693,7 +700,11 @@ func GetFullName(e proto.Visitee) string {

walk(e)

return strings.Join(candy.Reverse(names), "_")
return candy.Reverse(names)
}

func GetFullName(e proto.Visitee) string {
return strings.Join(GetFullNames(e), "_")
}

func (p *PbPackage) Walk() {
Expand Down
Loading

0 comments on commit d20fdad

Please sign in to comment.