Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add interfaces flag with unit test #200

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"errors"
"flag"
"fmt"
"go/ast"
"go/build"
Expand Down Expand Up @@ -92,6 +93,20 @@ func sourceMode(source string) (*model.Package, error) {
for pkgPath := range dotImports {
pkg.DotImports = append(pkg.DotImports, pkgPath)
}

// Get positional arguments after the flags
ifaces := flag.Args()

// If there are interfaces provided as positional arguments, filter them
if len(ifaces) > 0 {
if pkg.Interfaces, err = filterInterfaces(pkg.Interfaces, ifaces); err != nil {
log.Fatalf("Filtering interfaces failed: %v", err)
}
} else {
// No interfaces provided, process all interfaces for backward compatibility
log.Printf("No interfaces specified, processing all interfaces")
}

return pkg, nil
}

Expand Down Expand Up @@ -802,4 +817,36 @@ func packageNameOfDir(srcDir string) (string, error) {
return packageImport, nil
}

func filterInterfaces(all []*model.Interface, requested []string) ([]*model.Interface, error) {
// If no interfaces are requested, return all interfaces
if len(requested) == 0 {
return all, nil
}

requestedIfaces := make(map[string]struct{})
for _, iface := range requested {
requestedIfaces[iface] = struct{}{}
}

result := make([]*model.Interface, 0, len(requestedIfaces))
for _, iface := range all {
// Only add interfaces that are requested
if _, ok := requestedIfaces[iface.Name]; ok {
result = append(result, iface)
delete(requestedIfaces, iface.Name) // Remove matched iface from requested
}
}

// If any requested interfaces were not found, return an error
if len(requestedIfaces) > 0 {
var missing []string
for iface := range requestedIfaces {
missing = append(missing, iface)
}
return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", "))
}

return result, nil
}

var errOutsideGoPath = errors.New("source directory is outside GOPATH")
127 changes: 127 additions & 0 deletions mockgen/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package main
import (
"go/parser"
"go/token"
"reflect"
"testing"

"go.uber.org/mock/mockgen/model"
)

func TestFileParser_ParseFile(t *testing.T) {
Expand Down Expand Up @@ -143,3 +146,127 @@ func TestParseArrayWithConstLength(t *testing.T) {
}
}
}

func Test_filterInterfaces(t *testing.T) {
type args struct {
all []*model.Interface
requested []string
}
tests := []struct {
name string
args args
want []*model.Interface
wantErr bool
}{
{
name: "no filter (returns all interfaces)",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{},
},
want: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
wantErr: false,
},
{
name: "filter by Foo",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Foo"},
},
want: []*model.Interface{
{
Name: "Foo",
},
},
wantErr: false,
},
{
name: "filter by Foo and Bar",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Foo", "Bar"},
},
want: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
wantErr: false,
},
{
name: "incorrect filter by Foo and Baz",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Foo", "Baz"},
},
want: nil,
wantErr: true,
},
{
name: "missing interface (Baz not found)",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Baz"},
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := filterInterfaces(tt.args.all, tt.args.requested)
if (err != nil) != tt.wantErr {
t.Errorf("filterInterfaces() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("filterInterfaces() got = %v, want %v", got, tt.want)
}
})
}
}