Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

gomock generics support enhancement #663

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions mockgen/generic_go118.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,182 @@ func getIdentTypeParams(decl interface{}) string {
sb.WriteString("]")
return sb.String()
}

func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *ast.Field, pkg string, tps map[string]bool) (wasGeneric bool, err error) {
switch v := field.Type.(type) {
case *ast.IndexExpr, *ast.IndexListExpr:
wasGeneric = true
// generic embedded interface
// may or may not be external pkg
// *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T]
// *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K]
var (
ident *ast.Ident
selIdent *ast.Ident // selector identity only used in external import
// path string
typeParams []model.Type // normalize to slice whether IndexExpr or IndexListExpr to make it consistent to work with
)
if ie, ok := v.(*ast.IndexExpr); ok {
if se, ok := ie.X.(*ast.SelectorExpr); ok {
ident, selIdent = se.X.(*ast.Ident), se.Sel
} else {
ident = ie.X.(*ast.Ident)
}
var typParam model.Type
if typParam, err = p.parseType(pkg, ie.Index, tps); err != nil {
return
}
typeParams = append(typeParams, typParam)
} else {
ile := v.(*ast.IndexListExpr)
if se, ok := ile.X.(*ast.SelectorExpr); ok {
ident, selIdent = se.X.(*ast.Ident), se.Sel
} else {
ident = ile.X.(*ast.Ident)
}
var typParam model.Type
for i := range ile.Indices {
if typParam, err = p.parseType(pkg, ile.Indices[i], tps); err != nil {
return
}
typeParams = append(typeParams, typParam)
}
}

var (
embeddedIface *model.Interface
)

if selIdent == nil {
if embeddedIface, err = p.retrieveEmbeddedIfaceModel(pkg, ident.Name, ident.Pos(), false); err != nil {
return
}
} else {
filePkg, sel := ident.String(), selIdent.String()
if embeddedIface, err = p.retrieveEmbeddedIfaceModel(filePkg, sel, ident.Pos(), true); err != nil {
return
}
}

// Copy the methods.
// TODO: apply shadowing rules.
for _, m := range embeddedIface.Methods {
// non-trivial part - we have to match up the as-used type params with the as-defined
// defined as DoSomething[T any, K any]
// used as DoSomething[somPkg.SomeType, int64]
// meaning methods may be like in definition:
// Do(T) (K, error)
// but need to be like this in implementation:
// Do(somePkg.SomeType) (int64, error)
gm := m.Clone() // clone so we can change without changing source def

// overwrite all typed params for incoming/outgoing params
// to get the implementor-specified typing over the definition-specified typing

for pinIdx, pin := range gm.In {
if genType, hasGeneric := p.getTypedParamForGeneric(pin.Type, embeddedIface, typeParams); hasGeneric {
gm.In[pinIdx].Type = genType
}
}
for outIdx, out := range gm.Out {
if genType, hasGeneric := p.getTypedParamForGeneric(out.Type, embeddedIface, typeParams); hasGeneric {
gm.Out[outIdx].Type = genType
}
}
if gm.Variadic != nil {
if vGenType, hasGeneric := p.getTypedParamForGeneric(gm.Variadic.Type, embeddedIface, typeParams); hasGeneric {
gm.Variadic.Type = vGenType
}
}

iface.AddMethod(gm)
}
}

return
}

// getTypedParamForGeneric is recursive func to hydrate all generic types within a model.Type
// so they get populated instead with the actual desired target types
func (p *fileParser) getTypedParamForGeneric(t model.Type, iface *model.Interface, knownTypeParams []model.Type) (model.Type, bool) {
switch typ := t.(type) {
case *model.ArrayType:
if gType, wasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); wasGeneric {
typ.Type = gType
return typ, true
}
case *model.ChanType:
if gType, wasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); wasGeneric {
typ.Type = gType
return typ, true
}
case *model.FuncType:
hasGeneric := false
for inIdx, inParam := range typ.In {
if genType, ok := p.getTypedParamForGeneric(inParam.Type, iface, knownTypeParams); ok {
hasGeneric = true
typ.In[inIdx].Type = genType
}
}
for outIdx, outParam := range typ.Out {
if genType, ok := p.getTypedParamForGeneric(outParam.Type, iface, knownTypeParams); ok {
hasGeneric = true
typ.Out[outIdx].Type = genType
}
}
if typ.Variadic != nil {
if genType, ok := p.getTypedParamForGeneric(typ.Variadic.Type, iface, knownTypeParams); ok {
hasGeneric = true
typ.Variadic.Type = genType
}
}
if hasGeneric {
return typ, true
}
case *model.MapType:
var (
keyTyp, valTyp model.Type
wasKeyGeneric, wasValGeneric bool
)
if keyTyp, wasKeyGeneric = p.getTypedParamForGeneric(typ.Key, iface, knownTypeParams); wasKeyGeneric {
typ.Key = keyTyp
}
if valTyp, wasValGeneric = p.getTypedParamForGeneric(typ.Value, iface, knownTypeParams); wasValGeneric {
typ.Value = valTyp
}
if wasKeyGeneric || wasValGeneric {
return typ, true
}
case *model.NamedType:
if typ.TypeParams == nil {
return nil, false
}
hasGeneric := false
for i, tp := range typ.TypeParams.TypeParameters {
// it will either be a type with name matching a generic parameter
// or it will be something like ptr or slice etc...
if srcParamIdx := iface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(knownTypeParams) {
hasGeneric = true
dstParamTyp := knownTypeParams[srcParamIdx]
typ.TypeParams.TypeParameters[i] = dstParamTyp
} else if _, ok := p.getTypedParamForGeneric(tp, iface, knownTypeParams); ok {
hasGeneric = true
}
}
if hasGeneric {
return typ, true
}
case model.PredeclaredType:
if srcParamIdx := iface.TypeParamIndexByName(typ.String(nil, "")); srcParamIdx > -1 {
dstParamTyp := knownTypeParams[srcParamIdx]
return dstParamTyp, true
}
case *model.PointerType:
if gType, hasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); hasGeneric {
typ.Type = gType
return typ, true
}
}

return nil, false
}
4 changes: 4 additions & 0 deletions mockgen/generic_notgo118.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]b
func getIdentTypeParams(decl interface{}) string {
return ""
}

func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *ast.Field, pkg string, tps map[string]bool) (wasGeneric bool, err error) {
return false, nil
}
8 changes: 8 additions & 0 deletions mockgen/internal/tests/generics/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Bar[T any, R any] interface {
Seventeen() (*Foo[other.Three, other.Four], error)
Eighteen() (Iface[*other.Five], error)
Nineteen() AliasType
Twenty(*other.One[T]) *other.Two[T, R]
TwentyOne(*string) *other.Two[*T, *R]
}

type Foo[T any, R any] struct{}
Expand All @@ -38,3 +40,9 @@ type StructType struct{}
type StructType2 struct{}

type AliasType Baz[other.Three]

type EmbeddingIface interface {
Bar[other.Three, error]
other.Otherer[StructType, other.Five]
LocalFunc() error
}
11 changes: 11 additions & 0 deletions mockgen/internal/tests/generics/other/other.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,14 @@ type Three struct{}
type Four struct{}

type Five interface{}

type Otherer[T any, R any] interface {
DoT(T) error
DoR(R) error
MakeThem(...T) (R, error)
GetThem() ([]T, error)
GetThemPtr() ([]*T, error)
GetThemMapped() ([]map[int64]*T, error)
GetMap() (map[bool]T, error)
AddChan(chan T) error
}
Loading