-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.go
194 lines (179 loc) · 5.2 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// This is free and unencumbered software released into the public
// domain. For more information, see <http://unlicense.org> or the
// accompanying UNLICENSE file.
package main
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/nelsam/hel/mocks"
"github.com/nelsam/hel/packages"
"github.com/nelsam/hel/types"
"github.com/spf13/cobra"
)
// cmd is the root "hel" command. init builds it and main executes it.
var cmd *cobra.Command

// goimportsPath is the location of the goimports binary. It is resolved in
// init and run with -w on every generated mock file.
var goimportsPath string
// init resolves the goimports binary and builds the root "hel" command,
// including its flags.
//
// goimports must be on PATH because every generated mock file is rewritten
// with `goimports -w`. If it cannot be found, hel prints installation advice
// and exits with status 1.
func init() {
	// exec.LookPath performs the PATH search itself, so it works on platforms
	// without a `which` binary and does not spawn a child process.
	path, err := exec.LookPath("goimports")
	if err != nil {
		fmt.Println("Could not locate goimports:", err)
		fmt.Println("If goimports is not installed, please install it somewhere in your path. " +
			"See https://godoc.org/golang.org/x/tools/cmd/goimports.")
		os.Exit(1)
	}
	goimportsPath = path

	cmd = &cobra.Command{
		Use:   "hel",
		Short: "A mock generator for Go",
		Long: "hel is a simple mock generator. The origin of the name is the Norse goddess, Hel, " +
			"who guards over the souls of those unworthy to enter Valhalla. You can probably " +
			"guess how much I like mocks.",
		Run: func(cmd *cobra.Command, args []string) {
			if len(args) > 0 {
				fmt.Print("Invalid usage. Help:\n\n")
				// Hand the help func the command it should describe instead
				// of nil.
				cmd.HelpFunc()(cmd, args)
				os.Exit(1)
			}
			// These lookups can only fail if a name drifts from the flag
			// definitions at the end of init, i.e. a programmer error, so
			// panicking is appropriate.
			packagePatterns, err := cmd.Flags().GetStringSlice("package")
			if err != nil {
				panic(err)
			}
			typePatterns, err := cmd.Flags().GetStringSlice("type")
			if err != nil {
				panic(err)
			}
			outputName, err := cmd.Flags().GetString("output")
			if err != nil {
				panic(err)
			}
			chanSize, err := cmd.Flags().GetInt("chan-size")
			if err != nil {
				panic(err)
			}
			blockingReturn, err := cmd.Flags().GetBool("blocking-return")
			if err != nil {
				panic(err)
			}
			noTestPkg, err := cmd.Flags().GetBool("no-test-package")
			if err != nil {
				panic(err)
			}

			fmt.Printf("Loading directories matching %s %v", pluralize(packagePatterns, "pattern", "patterns"), packagePatterns)
			var dirList []packages.Dir
			progress(func() {
				dirList = packages.Load(packagePatterns...)
			})
			fmt.Print("\n")
			fmt.Println("Found directories:")
			for _, dir := range dirList {
				fmt.Println(" " + dir.Path())
			}
			fmt.Print("\n")

			fmt.Print("Loading interface types in matching directories")
			var typeDirs types.Dirs
			progress(func() {
				godirs := make([]types.GoDir, 0, len(dirList))
				for _, dir := range dirList {
					godirs = append(godirs, dir)
				}
				typeDirs = types.Load(godirs...).Filter(typePatterns...)
			})
			fmt.Print("\n\n")

			fmt.Printf("Generating mocks in output file %s", outputName)
			// Failures here are user-facing (unwritable directory, goimports
			// rejecting the output). Record the first one and report it
			// after the progress dots stop, instead of panicking with a stack
			// trace.
			var genErr error
			progress(func() {
				for _, typeDir := range typeDirs {
					mockPath, err := makeMocks(typeDir, outputName, chanSize, blockingReturn, !noTestPkg)
					if err != nil {
						genErr = fmt.Errorf("generating mocks in %s: %v", typeDir.Dir(), err)
						return
					}
					if mockPath == "" {
						// Nothing to mock in this directory; no file was written.
						continue
					}
					// Capture goimports' own output so a formatting failure
					// explains itself.
					if out, err := exec.Command(goimportsPath, "-w", mockPath).CombinedOutput(); err != nil {
						genErr = fmt.Errorf("running goimports on %s: %v: %s", mockPath, err, strings.TrimSpace(string(out)))
						return
					}
				}
			})
			fmt.Print("\n")
			if genErr != nil {
				fmt.Fprintln(os.Stderr, "Error:", genErr)
				os.Exit(1)
			}
		},
	}
	cmd.Flags().StringSliceP("package", "p", []string{"."}, "The package(s) to generate mocks for.")
	cmd.Flags().StringSliceP("type", "t", []string{}, "The type(s) to generate mocks for. If no types "+
		"are passed in, all exported interface types will be generated.")
	cmd.Flags().StringP("output", "o", "helheim_test.go", "The file to write generated mocks to. Since hel does "+
		"not generate exported types, this file will be saved directly in all packages with generated mocks. "+
		"Also note that, since the types are not exported, you will want the file to end in '_test.go'.")
	cmd.Flags().IntP("chan-size", "s", 100, "The size of channels used for method calls.")
	cmd.Flags().BoolP("blocking-return", "b", false, "Always block when returning from mock even if there is no return value.")
	cmd.Flags().Bool("no-test-package", false, "Generate mocks in the primary package rather than in {pkg}_test")
}
// makeMocks generates mocks for the types in dir and writes them to fileName
// inside dir's directory.
//
// It returns the path of the written file. When mocks.Generate yields no
// mocks, it returns "" with a nil error and writes nothing.
//
// chanSize and blockingReturn are forwarded to the generated mocks. When
// useTestPkg is true, the file is declared in package "<pkg>_test" and the
// mocks are told about the local package via PrependLocalPackage.
//
// An error from closing the file is returned if no earlier error occurred.
func makeMocks(dir types.Dir, fileName string, chanSize int, blockingReturn, useTestPkg bool) (filePath string, err error) {
	// Use local names that do not shadow the mocks and types packages.
	generated, err := mocks.Generate(dir)
	if err != nil {
		return "", err
	}
	if len(generated) == 0 {
		return "", nil
	}
	generated.SetBlockingReturn(blockingReturn)

	pkgName := dir.Package()
	if useTestPkg {
		generated.PrependLocalPackage(pkgName)
		pkgName += "_test"
	}

	filePath = filepath.Join(dir.Dir(), fileName)
	f, err := os.Create(filePath)
	if err != nil {
		return "", err
	}
	// A failed Close can mean the file was not fully written. Surface it
	// rather than dropping it, but never mask an earlier Output error.
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	return filePath, generated.Output(pkgName, dir.Dir(), chanSize, f)
}
// progress runs f while a background goroutine prints a "." to stdout every
// half second. It returns only after f has finished and the background
// goroutine has exited, so no dots are printed after progress returns.
func progress(f func()) {
	stop := make(chan struct{})
	done := make(chan struct{})
	go showProgress(stop, done)
	// The deferred shutdown also runs if f panics.
	defer func() {
		close(stop)
		<-done
	}()
	f()
}

// showProgress prints a "." on every tick of a 500ms ticker until stop is
// closed. It then stops the ticker and closes done to signal that it has
// returned.
func showProgress(stop <-chan struct{}, done chan<- struct{}) {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer func() {
		ticker.Stop()
		close(done)
	}()
	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			fmt.Print(".")
		}
	}
}
// lengther is implemented by values that report their own length.
// findLength prefers it over reflection.
type lengther interface {
	Len() int
}

// pluralize returns singular when findLength(values) is exactly 1, and plural
// otherwise (including zero).
func pluralize(values interface{}, singular, plural string) string {
	if findLength(values) == 1 {
		return singular
	}
	return plural
}

// findLength reports how many items values represents:
//   - a lengther reports its own Len();
//   - a nil interface counts as 0;
//   - an integer is treated as the count itself;
//   - arrays, channels, maps, slices and strings use their built-in length.
//
// Any other type is a programming error and panics with a message naming
// the type, rather than an opaque reflect panic.
func findLength(values interface{}) int {
	if l, ok := values.(lengther); ok {
		return l.Len()
	}
	if values == nil {
		return 0
	}
	v := reflect.ValueOf(values)
	switch v.Kind() {
	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
		return v.Len()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return int(v.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return int(v.Uint())
	}
	panic(fmt.Sprintf("findLength: cannot determine length of %T", values))
}
// main runs the hel command and exits with status 1 if it reports an error.
// Cobra has already printed the error by the time Execute returns, so only
// the exit status is set here.
func main() {
	if err := cmd.Execute(); err != nil {
		os.Exit(1)
	}
}