Skip to content

Commit

Permalink
Adds initial generics support
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanmoran committed Sep 4, 2024
1 parent fd2b3d4 commit dd7d6f7
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 41 deletions.
32 changes: 32 additions & 0 deletions acceptance/fixtures/fakes/generic_interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package fakes

import (
"sync"

"github.com/ryanmoran/faux/acceptance/fixtures"
)

type GenericInterface[T comparable, S comparable] struct {
SomeMethodCall struct {
mutex sync.Mutex
CallCount int
Receives struct {
MapTS map[T]S
}
Returns struct {
ResultIntError fixtures.Result[int, error]
}
Stub func(map[T]S) fixtures.Result[int, error]
}
}

func (f *GenericInterface[T, S]) SomeMethod(param1 map[T]S) fixtures.Result[int, error] {
f.SomeMethodCall.mutex.Lock()
defer f.SomeMethodCall.mutex.Unlock()
f.SomeMethodCall.CallCount++
f.SomeMethodCall.Receives.MapTS = param1
if f.SomeMethodCall.Stub != nil {
return f.SomeMethodCall.Stub(param1)
}
return f.SomeMethodCall.Returns.ResultIntError
}
9 changes: 9 additions & 0 deletions acceptance/fixtures/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ type NamedInterface interface {
SomeMethod(someParam *bytes.Buffer) (someResult io.Reader)
}

type Result[T, E any] struct {
Value T
Error E
}

type GenericInterface[T, S comparable] interface {
SomeMethod(map[T]S) Result[int, error]
}

type BurntSushiParser struct {
Key toml.Key
}
1 change: 1 addition & 0 deletions acceptance/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ var _ = Describe("faux", func() {
Entry("variadic", "variadic_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "VariadicInterface"),
Entry("functions", "function_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "FunctionInterface"),
Entry("name", "named_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "NamedInterface", "--name", "SomeNamedInterface"),
Entry("generic", "generic_interface.go", "--file", "./fixtures/interfaces.go", "--interface", "GenericInterface"),
)

Context("when the source file is provided via an environment variable", func() {
Expand Down
13 changes: 12 additions & 1 deletion parsing/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,23 @@ import (
type Argument struct {
Name string
Type types.Type
TypeArgs []types.Type
Variadic bool
Package string
}

func NewArgument(v *types.Var, variadic bool) Argument {
var pkg string
var (
pkg string
typeArgs []types.Type
)

if t, ok := v.Type().(*types.Named); ok {
targs := t.TypeArgs()
for i := 0; i < targs.Len(); i++ {
typeArgs = append(typeArgs, targs.At(i))
}

if t.Obj().Pkg() != nil {
pkg = t.Obj().Pkg().Path()
}
Expand All @@ -22,6 +32,7 @@ func NewArgument(v *types.Var, variadic bool) Argument {
return Argument{
Name: v.Name(),
Type: v.Type(),
TypeArgs: typeArgs,
Variadic: variadic,
Package: pkg,
}
Expand Down
7 changes: 7 additions & 0 deletions parsing/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@ import (

type Interface struct {
Name string
TypeArgs []*types.TypeParam
Signatures []Signature
}

func NewInterface(n *types.Named) (Interface, error) {
var signatures []Signature

var targs []*types.TypeParam
for i := 0; i < n.TypeParams().Len(); i++ {
targs = append(targs, n.TypeParams().At(i))
}

underlying, ok := n.Underlying().(*types.Interface)
if !ok {
return Interface{}, fmt.Errorf("failed to load underlying type: %s is not an interface", n.Underlying())
Expand All @@ -24,6 +30,7 @@ func NewInterface(n *types.Named) (Interface, error) {

return Interface{
Name: n.Obj().Name(),
TypeArgs: targs,
Signatures: signatures,
}, nil
}
33 changes: 33 additions & 0 deletions parsing/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,39 @@ var _ = Describe("Interface", func() {
})
})

Context("when the interface has type params", func() {
var typeParam *types.TypeParam

BeforeEach(func() {
signature := types.NewSignature(nil, nil, nil, false)
methods := []*types.Func{
types.NewFunc(0, pkg, "SomeMethod", signature),
}

underlying = types.NewInterfaceType(methods, nil).Complete()
namedType = types.NewNamed(typeName, underlying, nil)

typeName := types.NewTypeName(0, pkg, "T", nil)
constraint := types.NewNamed(types.NewTypeName(0, nil, "any", nil), types.NewInterface(nil, nil), nil)
typeParam = types.NewTypeParam(typeName, constraint)
namedType.SetTypeParams([]*types.TypeParam{typeParam})
})

It("includes those methods in the parsed interface", func() {
iface, err := parsing.NewInterface(namedType)
Expect(err).NotTo(HaveOccurred())
Expect(iface).To(Equal(parsing.Interface{
Name: "SomeType",
Signatures: []parsing.Signature{
{
Name: "SomeMethod",
},
},
TypeArgs: []*types.TypeParam{typeParam},
}))
})
})

Context("when the underlying type is not interface", func() {
BeforeEach(func() {
intType := types.Universe.Lookup("int").Type()
Expand Down
4 changes: 3 additions & 1 deletion parsing/signature.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package parsing

import "go/types"
import (
"go/types"
)

type Signature struct {
Name string
Expand Down
8 changes: 8 additions & 0 deletions parsing/type_param.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package parsing

import "go/types"

type TypeParam struct {
Name string
Constraint types.Type
}
21 changes: 13 additions & 8 deletions rendering/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ func (c *Context) BuildFakeType(iface parsing.Interface) NamedType {
calls = append(calls, c.BuildCallStruct(signature))
}

return NewNamedType(TitleString(iface.Name), NewStruct(calls))
var targs []Type
for _, targ := range iface.TypeArgs {
targs = append(targs, NewType(targ, nil))
}

return NewNamedType(TitleString(iface.Name), NewStruct(calls), targs)
}

func (c *Context) BuildCallStruct(signature parsing.Signature) Field {
Expand All @@ -58,7 +63,7 @@ func (c *Context) BuildCallStruct(signature parsing.Signature) Field {
}

func (c *Context) BuildMutex() Field {
return NewField("mutex", NewNamedType("sync.Mutex", NewStruct(nil)))
return NewField("mutex", NewNamedType("sync.Mutex", NewStruct(nil), nil))
}

func (c *Context) BuildCallCount() Field {
Expand All @@ -74,7 +79,7 @@ func (c *Context) BuildReceives(args []parsing.Argument) Field {
}
name = TitleString(name)

field := NewField(name, NewType(arg.Type))
field := NewField(name, NewType(arg.Type, arg.TypeArgs))
fields = append(fields, field)
}

Expand All @@ -90,7 +95,7 @@ func (c *Context) BuildReturns(args []parsing.Argument) Field {
}
name = TitleString(name)

field := NewField(name, NewType(arg.Type))
field := NewField(name, NewType(arg.Type, arg.TypeArgs))
fields = append(fields, field)
}

Expand Down Expand Up @@ -121,7 +126,7 @@ func (c *Context) BuildParams(args []parsing.Argument, named bool) []Param {
name = ParamName(i)
}

params = append(params, NewParam(name, NewType(arg.Type), arg.Variadic))
params = append(params, NewParam(name, NewType(arg.Type, arg.TypeArgs), arg.Variadic))
}

return params
Expand All @@ -130,7 +135,7 @@ func (c *Context) BuildParams(args []parsing.Argument, named bool) []Param {
func (c *Context) BuildResults(args []parsing.Argument) []Result {
var results []Result
for _, arg := range args {
results = append(results, NewResult(NewType(arg.Type)))
results = append(results, NewResult(NewType(arg.Type, arg.TypeArgs)))
}

return results
Expand All @@ -143,7 +148,7 @@ func (c *Context) BuildBody(receiver Receiver, signature parsing.Signature) []St
c.BuildIncrementStatement(receiver, signature.Name),
}

for i, _ := range signature.Params {
for i := range signature.Params {
statements = append(statements, c.BuildAssignStatement(receiver, signature.Name, i, signature.Params))
}

Expand Down Expand Up @@ -195,7 +200,7 @@ func (c *Context) BuildAssignStatement(receiver Receiver, name string, index int
paramField := receivesField.Type.(Struct).FieldWithName(argName)
selector := NewSelector(receiver, callField, receivesField, paramField)
paramName := ParamName(index)
param := NewParam(paramName, NewType(arg.Type), arg.Variadic)
param := NewParam(paramName, NewType(arg.Type, arg.TypeArgs), arg.Variadic)

return NewAssignStatement(selector, param)
}
Expand Down
70 changes: 52 additions & 18 deletions rendering/named_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,72 @@ import (
)

type NamedType struct {
Name string
Type Type
Name string
Type Type
TypeArgs []Type
}

func NewNamedType(name string, t Type) NamedType {
func NewNamedType(name string, t Type, targTypes []Type) NamedType {
return NamedType{
Name: name,
Type: t,
Name: name,
Type: t,
TypeArgs: targTypes,
}
}

func NewDefinedType(name string) NamedType {
return NamedType{
Name: name,
Type: Interface{},
}
func NewDefinedType(name string, targTypes []Type) NamedType {
return NewNamedType(name, Interface{}, targTypes)
}

func (nt NamedType) Expr() ast.Expr {
return ast.NewIdent(nt.Name)
switch len(nt.TypeArgs) {
case 0:
return ast.NewIdent(nt.Name)

case 1:
return &ast.IndexExpr{
X: ast.NewIdent(nt.Name),
Index: nt.TypeArgs[0].Expr(),
}

default:
var indices []ast.Expr
for _, typeArg := range nt.TypeArgs {
indices = append(indices, typeArg.Expr())
}

return &ast.IndexListExpr{
X: ast.NewIdent(nt.Name),
Indices: indices,
}
}
}

func (nt NamedType) isType() {}

func (nt NamedType) Decl() ast.Decl {
spec := &ast.TypeSpec{
Name: ast.NewIdent(nt.Name),
Type: nt.Type.Expr(),
}

if len(nt.TypeArgs) > 0 {
var fields []*ast.Field
for _, targ := range nt.TypeArgs {
ntarg := targ.(NamedType)
fields = append(fields, &ast.Field{
Names: []*ast.Ident{ast.NewIdent(ntarg.Name)},
Type: ntarg.Type.Expr(),
})
}

spec.TypeParams = &ast.FieldList{
List: fields,
}
}

return &ast.GenDecl{
Tok: token.TYPE,
Specs: []ast.Spec{
&ast.TypeSpec{
Name: ast.NewIdent(nt.Name),
Type: nt.Type.Expr(),
},
},
Tok: token.TYPE,
Specs: []ast.Spec{spec},
}
}
Loading

0 comments on commit dd7d6f7

Please sign in to comment.