diff --git a/mockgen/generic_go118.go b/mockgen/generic_go118.go index b29db9a8..22f66c4d 100644 --- a/mockgen/generic_go118.go +++ b/mockgen/generic_go118.go @@ -86,3 +86,182 @@ func getIdentTypeParams(decl interface{}) string { sb.WriteString("]") return sb.String() } + +func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *ast.Field, pkg string, tps map[string]bool) (wasGeneric bool, err error) { + switch v := field.Type.(type) { + case *ast.IndexExpr, *ast.IndexListExpr: + wasGeneric = true + // generic embedded interface + // may or may not be external pkg + // *ast.IndexExpr for embedded generic iface with single index e.g. DoSomething[T] + // *ast.IndexListExpr for embedded generic iface with multiple indexes e.g. DoSomething[T, K] + var ( + ident *ast.Ident + selIdent *ast.Ident // selector identity only used in external import + // path string + typeParams []model.Type // normalize to slice whether IndexExpr or IndexListExpr to make it consistent to work with + ) + if ie, ok := v.(*ast.IndexExpr); ok { + if se, ok := ie.X.(*ast.SelectorExpr); ok { + ident, selIdent = se.X.(*ast.Ident), se.Sel + } else { + ident = ie.X.(*ast.Ident) + } + var typParam model.Type + if typParam, err = p.parseType(pkg, ie.Index, tps); err != nil { + return + } + typeParams = append(typeParams, typParam) + } else { + ile := v.(*ast.IndexListExpr) + if se, ok := ile.X.(*ast.SelectorExpr); ok { + ident, selIdent = se.X.(*ast.Ident), se.Sel + } else { + ident = ile.X.(*ast.Ident) + } + var typParam model.Type + for i := range ile.Indices { + if typParam, err = p.parseType(pkg, ile.Indices[i], tps); err != nil { + return + } + typeParams = append(typeParams, typParam) + } + } + + var ( + embeddedIface *model.Interface + ) + + if selIdent == nil { + if embeddedIface, err = p.retrieveEmbeddedIfaceModel(pkg, ident.Name, ident.Pos(), false); err != nil { + return + } + } else { + filePkg, sel := ident.String(), selIdent.String() + if embeddedIface, err = p.retrieveEmbeddedIfaceModel(filePkg, sel, ident.Pos(), true); err != nil { + return + } + } + + // Copy the methods. + // TODO: apply shadowing rules. + for _, m := range embeddedIface.Methods { + // non-trivial part - we have to match up the as-used type params with the as-defined + // defined as DoSomething[T any, K any] + // used as DoSomething[somPkg.SomeType, int64] + // meaning methods may be like in definition: + // Do(T) (K, error) + // but need to be like this in implementation: + // Do(somePkg.SomeType) (int64, error) + gm := m.Clone() // clone so we can change without changing source def + + // overwrite all typed params for incoming/outgoing params + // to get the implementor-specified typing over the definition-specified typing + + for pinIdx, pin := range gm.In { + if genType, hasGeneric := p.getTypedParamForGeneric(pin.Type, embeddedIface, typeParams); hasGeneric { + gm.In[pinIdx].Type = genType + } + } + for outIdx, out := range gm.Out { + if genType, hasGeneric := p.getTypedParamForGeneric(out.Type, embeddedIface, typeParams); hasGeneric { + gm.Out[outIdx].Type = genType + } + } + if gm.Variadic != nil { + if vGenType, hasGeneric := p.getTypedParamForGeneric(gm.Variadic.Type, embeddedIface, typeParams); hasGeneric { + gm.Variadic.Type = vGenType + } + } + + iface.AddMethod(gm) + } + } + + return +} + +// getTypedParamForGeneric is recursive func to hydrate all generic types within a model.Type +// so they get populated instead with the actual desired target types +func (p *fileParser) getTypedParamForGeneric(t model.Type, iface *model.Interface, knownTypeParams []model.Type) (model.Type, bool) { + switch typ := t.(type) { + case *model.ArrayType: + if gType, wasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); wasGeneric { + typ.Type = gType + return typ, true + } + case *model.ChanType: + if gType, wasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); wasGeneric { + typ.Type = gType + return typ, true + } + case *model.FuncType: + hasGeneric := false + for inIdx, inParam := range typ.In { + if genType, ok := p.getTypedParamForGeneric(inParam.Type, iface, knownTypeParams); ok { + hasGeneric = true + typ.In[inIdx].Type = genType + } + } + for outIdx, outParam := range typ.Out { + if genType, ok := p.getTypedParamForGeneric(outParam.Type, iface, knownTypeParams); ok { + hasGeneric = true + typ.Out[outIdx].Type = genType + } + } + if typ.Variadic != nil { + if genType, ok := p.getTypedParamForGeneric(typ.Variadic.Type, iface, knownTypeParams); ok { + hasGeneric = true + typ.Variadic.Type = genType + } + } + if hasGeneric { + return typ, true + } + case *model.MapType: + var ( + keyTyp, valTyp model.Type + wasKeyGeneric, wasValGeneric bool + ) + if keyTyp, wasKeyGeneric = p.getTypedParamForGeneric(typ.Key, iface, knownTypeParams); wasKeyGeneric { + typ.Key = keyTyp + } + if valTyp, wasValGeneric = p.getTypedParamForGeneric(typ.Value, iface, knownTypeParams); wasValGeneric { + typ.Value = valTyp + } + if wasKeyGeneric || wasValGeneric { + return typ, true + } + case *model.NamedType: + if typ.TypeParams == nil { + return nil, false + } + hasGeneric := false + for i, tp := range typ.TypeParams.TypeParameters { + // it will either be a type with name matching a generic parameter + // or it will be something like ptr or slice etc... + if srcParamIdx := iface.TypeParamIndexByName(tp.String(nil, "")); srcParamIdx > -1 && srcParamIdx < len(knownTypeParams) { + hasGeneric = true + dstParamTyp := knownTypeParams[srcParamIdx] + typ.TypeParams.TypeParameters[i] = dstParamTyp + } else if _, ok := p.getTypedParamForGeneric(tp, iface, knownTypeParams); ok { + hasGeneric = true + } + } + if hasGeneric { + return typ, true + } + case model.PredeclaredType: + if srcParamIdx := iface.TypeParamIndexByName(typ.String(nil, "")); srcParamIdx > -1 { + dstParamTyp := knownTypeParams[srcParamIdx] + return dstParamTyp, true + } + case *model.PointerType: + if gType, hasGeneric := p.getTypedParamForGeneric(typ.Type, iface, knownTypeParams); hasGeneric { + typ.Type = gType + return typ, true + } + } + + return nil, false +} diff --git a/mockgen/generic_notgo118.go b/mockgen/generic_notgo118.go index 8fe48c17..87c8d715 100644 --- a/mockgen/generic_notgo118.go +++ b/mockgen/generic_notgo118.go @@ -34,3 +34,7 @@ func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]b func getIdentTypeParams(decl interface{}) string { return "" } + +func (p *fileParser) parseEmbeddedGenericIface(iface *model.Interface, field *ast.Field, pkg string, tps map[string]bool) (wasGeneric bool, err error) { + return false, nil +} diff --git a/mockgen/internal/tests/generics/generics.go b/mockgen/internal/tests/generics/generics.go index 0b389622..93df22d5 100644 --- a/mockgen/internal/tests/generics/generics.go +++ b/mockgen/internal/tests/generics/generics.go @@ -25,6 +25,8 @@ type Bar[T any, R any] interface { Seventeen() (*Foo[other.Three, other.Four], error) Eighteen() (Iface[*other.Five], error) Nineteen() AliasType + Twenty(*other.One[T]) *other.Two[T, R] + TwentyOne(*string) *other.Two[*T, *R] } type Foo[T any, R any] struct{} @@ -38,3 +40,9 @@ type StructType struct{} type StructType2 struct{} type AliasType Baz[other.Three] + +type EmbeddingIface interface { + Bar[other.Three, error] + other.Otherer[StructType, other.Five] + LocalFunc() error +} diff --git a/mockgen/internal/tests/generics/other/other.go b/mockgen/internal/tests/generics/other/other.go index 9265422b..39db6c23 100644 --- a/mockgen/internal/tests/generics/other/other.go +++ b/mockgen/internal/tests/generics/other/other.go @@ -9,3 +9,14 @@ type Three struct{} type Four struct{} type Five interface{} + +type Otherer[T any, R any] interface { + DoT(T) error + DoR(R) error + MakeThem(...T) (R, error) + GetThem() ([]T, error) + GetThemPtr() ([]*T, error) + GetThemMapped() ([]map[int64]*T, error) + GetMap() (map[bool]T, error) + AddChan(chan T) error +} diff --git a/mockgen/internal/tests/generics/source/mock_generics_test.go b/mockgen/internal/tests/generics/source/mock_generics_test.go index 0223e311..de6e1779 100644 --- a/mockgen/internal/tests/generics/source/mock_generics_test.go +++ b/mockgen/internal/tests/generics/source/mock_generics_test.go @@ -291,6 +291,34 @@ func (mr *MockBarMockRecorder[T, R]) Twelve() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twelve", reflect.TypeOf((*MockBar[T, R])(nil).Twelve)) } +// Twenty mocks base method. +func (m *MockBar[T, R]) Twenty(arg0 *other.One[T]) *other.Two[T, R] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Twenty", arg0) + ret0, _ := ret[0].(*other.Two[T, R]) + return ret0 +} + +// Twenty indicates an expected call of Twenty. +func (mr *MockBarMockRecorder[T, R]) Twenty(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twenty", reflect.TypeOf((*MockBar[T, R])(nil).Twenty), arg0) +} + +// TwentyOne mocks base method. +func (m *MockBar[T, R]) TwentyOne(arg0 *string) *other.Two[*T, *R] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TwentyOne", arg0) + ret0, _ := ret[0].(*other.Two[*T, *R]) + return ret0 +} + +// TwentyOne indicates an expected call of TwentyOne. +func (mr *MockBarMockRecorder[T, R]) TwentyOne(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TwentyOne", reflect.TypeOf((*MockBar[T, R])(nil).TwentyOne), arg0) +} + // Two mocks base method. func (m *MockBar[T, R]) Two(arg0 T) string { m.ctrl.T.Helper() @@ -327,3 +355,459 @@ func NewMockIface[T any](ctrl *gomock.Controller) *MockIface[T] { func (m *MockIface[T]) EXPECT() *MockIfaceMockRecorder[T] { return m.recorder } + +// MockEmbeddingIface is a mock of EmbeddingIface interface. +type MockEmbeddingIface struct { + ctrl *gomock.Controller + recorder *MockEmbeddingIfaceMockRecorder +} + +// MockEmbeddingIfaceMockRecorder is the mock recorder for MockEmbeddingIface. +type MockEmbeddingIfaceMockRecorder struct { + mock *MockEmbeddingIface +} + +// NewMockEmbeddingIface creates a new mock instance. +func NewMockEmbeddingIface(ctrl *gomock.Controller) *MockEmbeddingIface { + mock := &MockEmbeddingIface{ctrl: ctrl} + mock.recorder = &MockEmbeddingIfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEmbeddingIface) EXPECT() *MockEmbeddingIfaceMockRecorder { + return m.recorder +} + +// AddChan mocks base method. +func (m *MockEmbeddingIface) AddChan(arg0 chan generics.StructType) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddChan", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddChan indicates an expected call of AddChan. +func (mr *MockEmbeddingIfaceMockRecorder) AddChan(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddChan", reflect.TypeOf((*MockEmbeddingIface)(nil).AddChan), arg0) +} + +// DoR mocks base method. +func (m *MockEmbeddingIface) DoR(arg0 other.Five) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoR", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DoR indicates an expected call of DoR. +func (mr *MockEmbeddingIfaceMockRecorder) DoR(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoR", reflect.TypeOf((*MockEmbeddingIface)(nil).DoR), arg0) +} + +// DoT mocks base method. +func (m *MockEmbeddingIface) DoT(arg0 generics.StructType) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoT", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// DoT indicates an expected call of DoT. +func (mr *MockEmbeddingIfaceMockRecorder) DoT(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoT", reflect.TypeOf((*MockEmbeddingIface)(nil).DoT), arg0) +} + +// Eight mocks base method. +func (m *MockEmbeddingIface) Eight(arg0 other.Three) other.Two[other.Three, error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Eight", arg0) + ret0, _ := ret[0].(other.Two[other.Three, error]) + return ret0 +} + +// Eight indicates an expected call of Eight. +func (mr *MockEmbeddingIfaceMockRecorder) Eight(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Eight", reflect.TypeOf((*MockEmbeddingIface)(nil).Eight), arg0) +} + +// Eighteen mocks base method. +func (m *MockEmbeddingIface) Eighteen() (generics.Iface[*other.Five], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Eighteen") + ret0, _ := ret[0].(generics.Iface[*other.Five]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Eighteen indicates an expected call of Eighteen. +func (mr *MockEmbeddingIfaceMockRecorder) Eighteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Eighteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Eighteen)) +} + +// Eleven mocks base method. +func (m *MockEmbeddingIface) Eleven() (*other.One[other.Three], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Eleven") + ret0, _ := ret[0].(*other.One[other.Three]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Eleven indicates an expected call of Eleven. +func (mr *MockEmbeddingIfaceMockRecorder) Eleven() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Eleven", reflect.TypeOf((*MockEmbeddingIface)(nil).Eleven)) +} + +// Fifteen mocks base method. +func (m *MockEmbeddingIface) Fifteen() (generics.Iface[generics.StructType], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Fifteen") + ret0, _ := ret[0].(generics.Iface[generics.StructType]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Fifteen indicates an expected call of Fifteen. +func (mr *MockEmbeddingIfaceMockRecorder) Fifteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fifteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Fifteen)) +} + +// Five mocks base method. +func (m *MockEmbeddingIface) Five(arg0 other.Three) generics.Baz[other.Three] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Five", arg0) + ret0, _ := ret[0].(generics.Baz[other.Three]) + return ret0 +} + +// Five indicates an expected call of Five. +func (mr *MockEmbeddingIfaceMockRecorder) Five(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Five", reflect.TypeOf((*MockEmbeddingIface)(nil).Five), arg0) +} + +// Four mocks base method. +func (m *MockEmbeddingIface) Four(arg0 other.Three) generics.Foo[other.Three, error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Four", arg0) + ret0, _ := ret[0].(generics.Foo[other.Three, error]) + return ret0 +} + +// Four indicates an expected call of Four. +func (mr *MockEmbeddingIfaceMockRecorder) Four(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Four", reflect.TypeOf((*MockEmbeddingIface)(nil).Four), arg0) +} + +// Fourteen mocks base method. +func (m *MockEmbeddingIface) Fourteen() (*generics.Foo[generics.StructType, generics.StructType2], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Fourteen") + ret0, _ := ret[0].(*generics.Foo[generics.StructType, generics.StructType2]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Fourteen indicates an expected call of Fourteen. +func (mr *MockEmbeddingIfaceMockRecorder) Fourteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fourteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Fourteen)) +} + +// GetMap mocks base method. +func (m *MockEmbeddingIface) GetMap() (map[bool]generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMap") + ret0, _ := ret[0].(map[bool]generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMap indicates an expected call of GetMap. +func (mr *MockEmbeddingIfaceMockRecorder) GetMap() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMap", reflect.TypeOf((*MockEmbeddingIface)(nil).GetMap)) +} + +// GetThem mocks base method. +func (m *MockEmbeddingIface) GetThem() ([]generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetThem") + ret0, _ := ret[0].([]generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetThem indicates an expected call of GetThem. +func (mr *MockEmbeddingIfaceMockRecorder) GetThem() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThem", reflect.TypeOf((*MockEmbeddingIface)(nil).GetThem)) +} + +// GetThemMapped mocks base method. +func (m *MockEmbeddingIface) GetThemMapped() ([]map[int64]*generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetThemMapped") + ret0, _ := ret[0].([]map[int64]*generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetThemMapped indicates an expected call of GetThemMapped. +func (mr *MockEmbeddingIfaceMockRecorder) GetThemMapped() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThemMapped", reflect.TypeOf((*MockEmbeddingIface)(nil).GetThemMapped)) +} + +// GetThemPtr mocks base method. +func (m *MockEmbeddingIface) GetThemPtr() ([]*generics.StructType, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetThemPtr") + ret0, _ := ret[0].([]*generics.StructType) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetThemPtr indicates an expected call of GetThemPtr. +func (mr *MockEmbeddingIfaceMockRecorder) GetThemPtr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetThemPtr", reflect.TypeOf((*MockEmbeddingIface)(nil).GetThemPtr)) +} + +// LocalFunc mocks base method. +func (m *MockEmbeddingIface) LocalFunc() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalFunc") + ret0, _ := ret[0].(error) + return ret0 +} + +// LocalFunc indicates an expected call of LocalFunc. +func (mr *MockEmbeddingIfaceMockRecorder) LocalFunc() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalFunc", reflect.TypeOf((*MockEmbeddingIface)(nil).LocalFunc)) +} + +// MakeThem mocks base method. +func (m *MockEmbeddingIface) MakeThem(arg0 ...generics.StructType) (other.Five, error) { + m.ctrl.T.Helper() + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "MakeThem", varargs...) + ret0, _ := ret[0].(other.Five) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MakeThem indicates an expected call of MakeThem. +func (mr *MockEmbeddingIfaceMockRecorder) MakeThem(arg0 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeThem", reflect.TypeOf((*MockEmbeddingIface)(nil).MakeThem), arg0...) +} + +// Nine mocks base method. +func (m *MockEmbeddingIface) Nine(arg0 generics.Iface[other.Three]) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Nine", arg0) +} + +// Nine indicates an expected call of Nine. +func (mr *MockEmbeddingIfaceMockRecorder) Nine(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nine", reflect.TypeOf((*MockEmbeddingIface)(nil).Nine), arg0) +} + +// Nineteen mocks base method. +func (m *MockEmbeddingIface) Nineteen() generics.AliasType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Nineteen") + ret0, _ := ret[0].(generics.AliasType) + return ret0 +} + +// Nineteen indicates an expected call of Nineteen. +func (mr *MockEmbeddingIfaceMockRecorder) Nineteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nineteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Nineteen)) +} + +// One mocks base method. +func (m *MockEmbeddingIface) One(arg0 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "One", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// One indicates an expected call of One. +func (mr *MockEmbeddingIfaceMockRecorder) One(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "One", reflect.TypeOf((*MockEmbeddingIface)(nil).One), arg0) +} + +// Seven mocks base method. +func (m *MockEmbeddingIface) Seven(arg0 other.Three) other.One[other.Three] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Seven", arg0) + ret0, _ := ret[0].(other.One[other.Three]) + return ret0 +} + +// Seven indicates an expected call of Seven. +func (mr *MockEmbeddingIfaceMockRecorder) Seven(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seven", reflect.TypeOf((*MockEmbeddingIface)(nil).Seven), arg0) +} + +// Seventeen mocks base method. +func (m *MockEmbeddingIface) Seventeen() (*generics.Foo[other.Three, other.Four], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Seventeen") + ret0, _ := ret[0].(*generics.Foo[other.Three, other.Four]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Seventeen indicates an expected call of Seventeen. +func (mr *MockEmbeddingIfaceMockRecorder) Seventeen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seventeen", reflect.TypeOf((*MockEmbeddingIface)(nil).Seventeen)) +} + +// Six mocks base method. +func (m *MockEmbeddingIface) Six(arg0 other.Three) *generics.Baz[other.Three] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Six", arg0) + ret0, _ := ret[0].(*generics.Baz[other.Three]) + return ret0 +} + +// Six indicates an expected call of Six. +func (mr *MockEmbeddingIfaceMockRecorder) Six(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Six", reflect.TypeOf((*MockEmbeddingIface)(nil).Six), arg0) +} + +// Sixteen mocks base method. +func (m *MockEmbeddingIface) Sixteen() (generics.Baz[other.Three], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sixteen") + ret0, _ := ret[0].(generics.Baz[other.Three]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sixteen indicates an expected call of Sixteen. +func (mr *MockEmbeddingIfaceMockRecorder) Sixteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sixteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Sixteen)) +} + +// Ten mocks base method. +func (m *MockEmbeddingIface) Ten(arg0 *other.Three) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Ten", arg0) +} + +// Ten indicates an expected call of Ten. +func (mr *MockEmbeddingIfaceMockRecorder) Ten(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ten", reflect.TypeOf((*MockEmbeddingIface)(nil).Ten), arg0) +} + +// Thirteen mocks base method. +func (m *MockEmbeddingIface) Thirteen() (generics.Baz[generics.StructType], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Thirteen") + ret0, _ := ret[0].(generics.Baz[generics.StructType]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Thirteen indicates an expected call of Thirteen. +func (mr *MockEmbeddingIfaceMockRecorder) Thirteen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Thirteen", reflect.TypeOf((*MockEmbeddingIface)(nil).Thirteen)) +} + +// Three mocks base method. +func (m *MockEmbeddingIface) Three(arg0 other.Three) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Three", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Three indicates an expected call of Three. +func (mr *MockEmbeddingIfaceMockRecorder) Three(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Three", reflect.TypeOf((*MockEmbeddingIface)(nil).Three), arg0) +} + +// Twelve mocks base method. +func (m *MockEmbeddingIface) Twelve() (*other.Two[other.Three, error], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Twelve") + ret0, _ := ret[0].(*other.Two[other.Three, error]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Twelve indicates an expected call of Twelve. +func (mr *MockEmbeddingIfaceMockRecorder) Twelve() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twelve", reflect.TypeOf((*MockEmbeddingIface)(nil).Twelve)) +} + +// Twenty mocks base method. +func (m *MockEmbeddingIface) Twenty(arg0 *other.One[other.Three]) *other.Two[other.Three, error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Twenty", arg0) + ret0, _ := ret[0].(*other.Two[other.Three, error]) + return ret0 +} + +// Twenty indicates an expected call of Twenty. +func (mr *MockEmbeddingIfaceMockRecorder) Twenty(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Twenty", reflect.TypeOf((*MockEmbeddingIface)(nil).Twenty), arg0) +} + +// TwentyOne mocks base method. +func (m *MockEmbeddingIface) TwentyOne(arg0 *string) *other.Two[*other.Three, *error] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TwentyOne", arg0) + ret0, _ := ret[0].(*other.Two[*other.Three, *error]) + return ret0 +} + +// TwentyOne indicates an expected call of TwentyOne. +func (mr *MockEmbeddingIfaceMockRecorder) TwentyOne(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TwentyOne", reflect.TypeOf((*MockEmbeddingIface)(nil).TwentyOne), arg0) +} + +// Two mocks base method. +func (m *MockEmbeddingIface) Two(arg0 other.Three) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Two", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// Two indicates an expected call of Two. +func (mr *MockEmbeddingIfaceMockRecorder) Two(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Two", reflect.TypeOf((*MockEmbeddingIface)(nil).Two), arg0) +} diff --git a/mockgen/model/model.go b/mockgen/model/model.go index 94d7f4ba..5ed9462c 100644 --- a/mockgen/model/model.go +++ b/mockgen/model/model.go @@ -61,6 +61,38 @@ type Interface struct { TypeParams []*Parameter } +// TypeParamIndexByName returns the index of the type parameter matching on name. If none matching, returns -1. +// +// This is especially useful for generics where interface is something like this: +// Doer[T any, K any]{ +// Start(T) +// Add(K) error +// Stop() []K +// } +// +// But it is used like this: +// [ T , K ] +// type MyDoer = Doer[types.SomeType, otherPkg.SomeOtherThing] +// or as an embedded interface: +// type MyDoer interface { +// [ T , K ] +// Doer[types.SomeType, otherPkg.SomeOtherThing] +// } +// +// If parsing the Add method for an implementation of this interface, +// we need to be able to swap out K for whatever the actual type is. +// K will be at index 1 of the interface's TypeParams, +// but it will also be at index 1 of the actual type params in either the +// definition of the interface being mocked or the generic interface it embeds. +func (intf *Interface) TypeParamIndexByName(name string) int { + for i, p := range intf.TypeParams { + if p.Name == name { + return i + } + } + return -1 +} + // Print writes the interface name and its methods. func (intf *Interface) Print(w io.Writer) { _, _ = fmt.Fprintf(w, "interface %s\n", intf.Name) @@ -92,6 +124,28 @@ type Method struct { Variadic *Parameter // may be nil } +// Clone makes a deep clone of a Method. +// +// This is useful specifically for generics so that generic parameters +// from source interface methods (e.g. Iface[T any, R any]) +// can be swapped out with actualized types from a referencing entity +// (e.g. type OtherIface = Iface[external.Foo, Baz]). +func (m *Method) Clone() *Method { + mm := &Method{ + Name: m.Name, + In: make([]*Parameter, 0), + Out: make([]*Parameter, 0), + Variadic: m.Variadic.clone(), + } + for _, in := range m.In { + mm.In = append(mm.In, in.clone()) + } + for _, out := range m.Out { + mm.Out = append(mm.Out, out.clone()) + } + return mm +} + // Print writes the method name and its signature. func (m *Method) Print(w io.Writer) { _, _ = fmt.Fprintf(w, " - method %s\n", m.Name) @@ -131,6 +185,16 @@ type Parameter struct { Type Type } +func (p *Parameter) clone() *Parameter { + if p == nil { + return nil + } + return &Parameter{ + Name: p.Name, + Type: p.Type.clone(), + } +} + // Print writes a method parameter. func (p *Parameter) Print(w io.Writer) { n := p.Name @@ -144,6 +208,7 @@ func (p *Parameter) Print(w io.Writer) { type Type interface { String(pm map[string]string, pkgOverride string) string addImports(im map[string]bool) + clone() Type } func init() { @@ -180,6 +245,13 @@ func (at *ArrayType) String(pm map[string]string, pkgOverride string) string { func (at *ArrayType) addImports(im map[string]bool) { at.Type.addImports(im) } +func (at *ArrayType) clone() Type { + return &ArrayType{ + Len: at.Len, + Type: at.Type.clone(), + } +} + // ChanType is a channel type. type ChanType struct { Dir ChanDir // 0, 1 or 2 @@ -199,6 +271,13 @@ func (ct *ChanType) String(pm map[string]string, pkgOverride string) string { func (ct *ChanType) addImports(im map[string]bool) { ct.Type.addImports(im) } +func (ct *ChanType) clone() Type { + return &ChanType{ + Dir: ct.Dir, + Type: ct.Type.clone(), + } +} + // ChanDir is a channel direction. type ChanDir int @@ -247,6 +326,21 @@ func (ft *FuncType) addImports(im map[string]bool) { } } +func (ft *FuncType) clone() Type { + ftt := &FuncType{ + In: make([]*Parameter, 0), + Out: make([]*Parameter, 0), + Variadic: ft.Variadic.clone(), + } + for _, in := range ft.In { + ftt.In = append(ftt.In, in.clone()) + } + for _, out := range ft.Out { + ftt.Out = append(ftt.Out, out.clone()) + } + return ftt +} + // MapType is a map type. type MapType struct { Key, Value Type @@ -261,6 +355,13 @@ func (mt *MapType) addImports(im map[string]bool) { mt.Value.addImports(im) } +func (mt *MapType) clone() Type { + return &MapType{ + Key: mt.Key, + Value: mt.Value.clone(), + } +} + // NamedType is an exported type in a package. type NamedType struct { Package string // may be empty @@ -287,6 +388,21 @@ func (nt *NamedType) addImports(im map[string]bool) { nt.TypeParams.addImports(im) } +func (nt *NamedType) clone() Type { + if nt == nil { + return nil + } + + ntt := &NamedType{ + Package: nt.Package, + Type: nt.Type, + } + if nt.TypeParams != nil { + ntt.TypeParams = nt.TypeParams.clone().(*TypeParametersType) + } + return ntt +} + // PointerType is a pointer to another type. type PointerType struct { Type Type @@ -297,11 +413,16 @@ func (pt *PointerType) String(pm map[string]string, pkgOverride string) string { } func (pt *PointerType) addImports(im map[string]bool) { pt.Type.addImports(im) } +func (pt *PointerType) clone() Type { + return &PointerType{Type: pt.Type.clone()} +} + // PredeclaredType is a predeclared type such as "int". type PredeclaredType string func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) } func (pt PredeclaredType) addImports(map[string]bool) {} +func (pt PredeclaredType) clone() Type { return PredeclaredType(pt) } // TypeParametersType contains type paramters for a NamedType. type TypeParametersType struct { @@ -333,6 +454,17 @@ func (tp *TypeParametersType) addImports(im map[string]bool) { } } +func (tp *TypeParametersType) clone() Type { + if tp == nil { + return nil + } + tpt := &TypeParametersType{} + for _, t := range tp.TypeParameters { + tpt.TypeParameters = append(tpt.TypeParameters, t.clone()) + } + return tpt +} + // The following code is intended to be called by the program generated by ../reflect.go. // InterfaceFromInterfaceType returns a pointer to an interface for the diff --git a/mockgen/parse.go b/mockgen/parse.go index 21c0d70a..3469029a 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -304,37 +304,9 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode iface.AddMethod(m) case *ast.Ident: // Embedded interface in this package. - embeddedIfaceType := p.auxInterfaces.Get(pkg, v.String()) - if embeddedIfaceType == nil { - embeddedIfaceType = p.importedInterfaces.Get(pkg, v.String()) - } - - var embeddedIface *model.Interface - if embeddedIfaceType != nil { - var err error - embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType) - if err != nil { - return nil, err - } - } else { - // This is built-in error interface. - if v.String() == model.ErrorInterface.Name { - embeddedIface = &model.ErrorInterface - } else { - ip, err := p.parsePackage(pkg) - if err != nil { - return nil, p.errorf(v.Pos(), "could not parse package %s: %v", pkg, err) - } - - if embeddedIfaceType = ip.importedInterfaces.Get(pkg, v.String()); embeddedIfaceType == nil { - return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", pkg, v.String()) - } - - embeddedIface, err = ip.parseInterface(v.String(), pkg, embeddedIfaceType) - if err != nil { - return nil, err - } - } + embeddedIface, err := p.retrieveEmbeddedIfaceModel(pkg, v.String(), v.Pos(), false) + if err != nil { + return nil, err } // Copy the methods. for _, m := range embeddedIface.Methods { @@ -343,40 +315,9 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode case *ast.SelectorExpr: // Embedded interface in another package. filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() - embeddedPkg, ok := p.imports[filePkg] - if !ok { - return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) - } - - var embeddedIface *model.Interface - var err error - embeddedIfaceType := p.auxInterfaces.Get(filePkg, sel) - if embeddedIfaceType != nil { - embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) - if err != nil { - return nil, err - } - } else { - path := embeddedPkg.Path() - parser := embeddedPkg.Parser() - if parser == nil { - ip, err := p.parsePackage(path) - if err != nil { - return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) - } - parser = ip - p.imports[filePkg] = importedPkg{ - path: embeddedPkg.Path(), - parser: parser, - } - } - if embeddedIfaceType = parser.importedInterfaces.Get(path, sel); embeddedIfaceType == nil { - return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) - } - embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) - if err != nil { - return nil, err - } + embeddedIface, err := p.retrieveEmbeddedIfaceModel(filePkg, sel, v.X.Pos(), true) + if err != nil { + return nil, err } // Copy the methods. // TODO: apply shadowing rules. @@ -384,12 +325,90 @@ func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*mode iface.AddMethod(m) } default: + if wasEmbeddedGeneric, err := p.parseEmbeddedGenericIface(iface, field, pkg, tps); wasEmbeddedGeneric { + if err != nil { + return nil, err + } + continue + } return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) } } return iface, nil } +func (p *fileParser) retrieveEmbeddedIfaceModel(pkg, ifaceName string, pos token.Pos, isImport bool) (m *model.Interface, err error) { + var ( + typ *namedInterface + importPkg importedPackage + ) + + if isImport { + var ok bool + if importPkg, ok = p.imports[pkg]; !ok { + err = p.errorf(pos, "unknown package %s", pkg) + return + } + } + + typ = p.auxInterfaces.Get(pkg, ifaceName) + if typ == nil { + typ = p.importedInterfaces.Get(pkg, ifaceName) + } + if typ != nil { + m, err = p.parseInterface(ifaceName, pkg, typ) + return + } + if ifaceName == model.ErrorInterface.Name { + // built-in error interface + m = &model.ErrorInterface + return + } + // parse from pkg (may be current pkg, may be imported pkg) + // so need to get the proper parser for the pkg + var ifaceParser *fileParser + + if importPkg != nil { + // imported pkg + if ifaceParser = importPkg.Parser(); ifaceParser == nil { + path := importPkg.Path() + if ifaceParser, err = p.parsePackage(path); err != nil { + err = p.errorf(pos, "could not parse package %s: %v", path, err) + return + } + p.imports[pkg] = importedPkg{ + path: importPkg.Path(), + parser: ifaceParser, + } + } + typ = ifaceParser.importedInterfaces.Get(importPkg.Path(), ifaceName) + } + + if ifaceParser == nil { + // this pkg + if ifaceParser, err = p.parsePackage(pkg); err != nil { + err = p.errorf(pos, "could not parse package %s: %v", pkg, err) + return + } + typ = ifaceParser.importedInterfaces.Get(pkg, ifaceName) + } + + if typ == nil { + err = p.errorf(pos, "unknown embedded interface %s.%s", pkg, ifaceName) + return + } + + if importPkg != nil { + pkg = importPkg.Path() + } + + // at this point, whether iface is of imported pkg or same pkg, + // the ifaceParser is appropriate and knows how to parse the iface + m, err = ifaceParser.parseInterface(ifaceName, pkg, typ) + + return +} + func (p *fileParser) parseFunc(pkg string, f *ast.FuncType, tps map[string]bool) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { if f.Params != nil { regParams := f.Params.List