Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for archive mode #125

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,23 @@ go install go.uber.org/mock/mockgen@latest

## Running mockgen

`mockgen` has two modes of operation: source and reflect.
`mockgen` has three modes of operation: archive, source and reflect.

### Archive mode

Archive mode generates mock interfaces from a package archive
file (.a). It is enabled by using the -archive flag. An import
path and a comma-separated list of symbols should be provided
as a non-flag argument to the command.

Example:

```bash
# Build the package to a archive.
go build -o pkg.a database/sql/driver

mockgen -archive=pkg.a database/sql/driver Conn,Driver
```

### Source mode

Expand Down Expand Up @@ -66,6 +82,8 @@ The `mockgen` command is used to generate source code for a mock
class given a Go source file containing interfaces to be mocked.
It supports the following flags:

- `-archive`: A package archive file containing interfaces to be mocked.

- `-source`: A file containing interfaces to be mocked.

- `-destination`: A file to which to write the resulting source code. If you
Expand Down
55 changes: 55 additions & 0 deletions mockgen/archive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package main

import (
"fmt"
"go/token"
"go/types"
"os"

"go.uber.org/mock/mockgen/model"

"golang.org/x/tools/go/gcexportdata"
)

func archiveMode(importPath string, symbols []string, archive string) (*model.Package, error) {
f, err := os.Open(archive)
if err != nil {
return nil, err
}
defer f.Close()
r, err := gcexportdata.NewReader(f)
if err != nil {
return nil, fmt.Errorf("read export data %q: %v", archive, err)
}

fset := token.NewFileSet()
imports := make(map[string]*types.Package)
tp, err := gcexportdata.Read(r, fset, imports, importPath)
if err != nil {
return nil, err
}

pkg := &model.Package{
Name: tp.Name(),
PkgPath: tp.Path(),
Interfaces: make([]*model.Interface, 0, len(symbols)),
}
for _, name := range symbols {
m := tp.Scope().Lookup(name)
tn, ok := m.(*types.TypeName)
if !ok {
continue
}
ti, ok := tn.Type().Underlying().(*types.Interface)
if !ok {
continue
}
it, err := model.InterfaceFromGoTypesType(ti)
if err != nil {
return nil, err
}
it.Name = m.Name()
pkg.Interfaces = append(pkg.Interfaces, it)
}
return pkg, nil
}
44 changes: 34 additions & 10 deletions mockgen/mockgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ var (
)

var (
archive = flag.String("archive", "", "(archive mode) Input Go archive file; enables archive mode.")
source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.")
destination = flag.String("destination", "", "Output file; defaults to stdout.")
mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.")
Expand All @@ -66,9 +67,8 @@ var (
imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
excludeInterfaces = flag.String("exclude_interfaces", "", "Comma-separated names of interfaces to be excluded")

debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
showVersion = flag.Bool("version", false, "Print version.")
debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
showVersion = flag.Bool("version", false, "Print version.")
)

func main() {
Expand All @@ -83,15 +83,22 @@ func main() {
var pkg *model.Package
var err error
var packageName string
if *source != "" {

// Switch between modes
switch {
case *source != "": // source mode
pkg, err = sourceMode(*source)
} else {
if flag.NArg() != 2 {
usage()
log.Fatal("Expected exactly two arguments")
}
case *archive != "": // archive mode
checkArgs()
packageName = flag.Arg(0)
interfaces := strings.Split(flag.Arg(1), ",")
pkg, err = archiveMode(packageName, interfaces, *archive)

default: // reflect mode
checkArgs()
packageName = flag.Arg(0)
interfaces := strings.Split(flag.Arg(1), ",")

if packageName == "." {
dir, err := os.Getwd()
if err != nil {
Expand All @@ -104,6 +111,7 @@ func main() {
}
pkg, err = reflectMode(packageName, interfaces)
}

if err != nil {
log.Fatalf("Loading input failed: %v", err)
}
Expand Down Expand Up @@ -144,6 +152,8 @@ func main() {
g := new(generator)
if *source != "" {
g.filename = *source
} else if *archive != "" {
g.filename = *archive
} else {
g.srcPackage = packageName
g.srcInterfaces = flag.Arg(1)
Expand Down Expand Up @@ -219,12 +229,19 @@ func parseExcludeInterfaces(names string) map[string]struct{} {
return namesSet
}

func checkArgs() {
if flag.NArg() != 2 {
usage()
log.Fatal("Expected exactly two arguments")
}
}

func usage() {
_, _ = io.WriteString(os.Stderr, usageText)
flag.PrintDefaults()
}

const usageText = `mockgen has two modes of operation: source and reflect.
const usageText = `mockgen has three modes of operation: archive, source and reflect.

Source mode generates mock interfaces from a source file.
It is enabled by using the -source flag. Other flags that
Expand All @@ -239,6 +256,13 @@ comma-separated list of symbols.
Example:
mockgen database/sql/driver Conn,Driver

Archive mode generates mock interfaces from a package archive
file (.a). It is enabled by using the -archive flag and two
non-flag arguments: an import path, and a comma-separated
list of symbols.
Example:
mockgen -archive=pkg.a database/sql/driver Conn,Driver

`

type generator struct {
Expand Down
160 changes: 160 additions & 0 deletions mockgen/model/model_gotypes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package model

import (
"fmt"
"go/types"
)

// InterfaceFromGoTypesType returns a pointer to an interface for the
// given interface type loaded from archive.
func InterfaceFromGoTypesType(it *types.Interface) (*Interface, error) {
intf := &Interface{}

for i := 0; i < it.NumMethods(); i++ {
mt := it.Method(i)
// Skip unexported methods.
if !mt.Exported() {
continue
}
m := &Method{
Name: mt.Name(),
}

var err error
m.In, m.Variadic, m.Out, err = funcArgsFromGoTypesType(mt.Type().(*types.Signature))
if err != nil {
return nil, fmt.Errorf("method %q: %w", mt.Name(), err)
}

intf.AddMethod(m)
}

return intf, nil
}

func funcArgsFromGoTypesType(t *types.Signature) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) {
nin := t.Params().Len()
if t.Variadic() {
nin--
}
for i := 0; i < nin; i++ {
p, err := parameterFromGoTypesType(t.Params().At(i), false)
if err != nil {
return nil, nil, nil, err
}
in = append(in, p)
}
if t.Variadic() {
p, err := parameterFromGoTypesType(t.Params().At(nin), true)
if err != nil {
return nil, nil, nil, err
}
variadic = p
}
for i := 0; i < t.Results().Len(); i++ {
p, err := parameterFromGoTypesType(t.Results().At(i), false)
if err != nil {
return nil, nil, nil, err
}
out = append(out, p)
}
return
}

func parameterFromGoTypesType(v *types.Var, variadic bool) (*Parameter, error) {
t := v.Type()
if variadic {
t = t.(*types.Slice).Elem()
}
tt, err := typeFromGoTypesType(t)
if err != nil {
return nil, err
}
return &Parameter{Name: v.Name(), Type: tt}, nil
}

func typeFromGoTypesType(t types.Type) (Type, error) {
if t, ok := t.(*types.Named); ok {
tn := t.Obj()
if tn.Pkg() == nil {
return PredeclaredType(tn.Name()), nil
}
return &NamedType{
Package: tn.Pkg().Path(),
Type: tn.Name(),
}, nil
}

// only unnamed or predeclared types after here

// Lots of types have element types. Let's do the parsing and error checking for all of them.
var elemType Type
if t, ok := t.(interface{ Elem() types.Type }); ok {
var err error
elemType, err = typeFromGoTypesType(t.Elem())
if err != nil {
return nil, err
}
}

switch t := t.(type) {
case *types.Array:
return &ArrayType{
Len: int(t.Len()),
Type: elemType,
}, nil
case *types.Basic:
return PredeclaredType(t.String()), nil
case *types.Chan:
var dir ChanDir
switch t.Dir() {
case types.RecvOnly:
dir = RecvDir
case types.SendOnly:
dir = SendDir
}
return &ChanType{
Dir: dir,
Type: elemType,
}, nil
case *types.Signature:
in, variadic, out, err := funcArgsFromGoTypesType(t)
if err != nil {
return nil, err
}
return &FuncType{
In: in,
Out: out,
Variadic: variadic,
}, nil
case *types.Interface:
if t.NumMethods() == 0 {
return PredeclaredType("interface{}"), nil
}
case *types.Map:
kt, err := typeFromGoTypesType(t.Key())
if err != nil {
return nil, err
}
return &MapType{
Key: kt,
Value: elemType,
}, nil
case *types.Pointer:
return &PointerType{
Type: elemType,
}, nil
case *types.Slice:
return &ArrayType{
Len: -1,
Type: elemType,
}, nil
case *types.Struct:
if t.NumFields() == 0 {
return PredeclaredType("struct{}"), nil
}
// TODO: UnsafePointer
}

return nil, fmt.Errorf("can't yet turn %v (%T) into a model.Type", t.String(), t)
}