Skip to content

Commit

Permalink
fix: define imported package's local name by using flags (#41)
Browse files Browse the repository at this point in the history
fixes [#646](golang/mock#646),
Hi, my friends.
Merging this pull request can fix old bugs related to defining the local
names of imported packages by using the -imports flag.
I have also added a test for that, and it works properly.
If you have any ideas or concerns regarding the PR, feel free to tell
me.

---------

Co-authored-by: Erfan Momeni <[email protected]>
Co-authored-by: n0trace <[email protected]>
Co-authored-by: utagawa kiki <[email protected]>
Co-authored-by: Sung Yoon Whang <[email protected]>
  • Loading branch information
5 people authored Aug 7, 2023
1 parent 89f1565 commit 2417c65
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 7 deletions.
13 changes: 13 additions & 0 deletions mockgen/internal/tests/defined_import_local_name/input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package defined_import_local_name

import (
"bytes"
"context"
)

//go:generate mockgen -package defined_import_local_name -destination mock.go -source input.go -imports b_mock=bytes,c_mock=context

type WithImports interface {
Method1() bytes.Buffer
Method2() context.Context
}
68 changes: 68 additions & 0 deletions mockgen/internal/tests/defined_import_local_name/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 18 additions & 1 deletion mockgen/mockgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ var (
writeSourceComment = flag.Bool("write_source_comment", true, "Writes original file (source mode) or interface names (reflect mode) comment if true.")
copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header")
typed = flag.Bool("typed", false, "Generate Type-safe 'Return', 'Do', 'DoAndReturn' function")
imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")

debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
showVersion = flag.Bool("version", false, "Print version.")
Expand Down Expand Up @@ -318,6 +320,16 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac

packagesName := createPackageMap(sortedPaths)

definedImports := make(map[string]string, len(im))
if *imports != "" {
for _, kv := range strings.Split(*imports, ",") {
eq := strings.Index(kv, "=")
if k, v := kv[:eq], kv[eq+1:]; k != "." {
definedImports[v] = k
}
}
}

g.packageMap = make(map[string]string, len(im))
localNames := make(map[string]bool, len(im))
for _, pth := range sortedPaths {
Expand All @@ -329,9 +341,14 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
// Local names for an imported package can usually be the basename of the import path.
// A couple of situations don't permit that, such as duplicate local names
// (e.g. importing "html/template" and "text/template"), or where the basename is
// a keyword (e.g. "foo/case").
// a keyword (e.g. "foo/case") or when defining a name for that by using the -imports flag.
// try base0, base1, ...
pkgName := base

if _, ok := definedImports[base]; ok {
pkgName = definedImports[base]
}

i := 0
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
pkgName = base + strconv.Itoa(i)
Expand Down
6 changes: 0 additions & 6 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package main

import (
"errors"
"flag"
"fmt"
"go/ast"
"go/build"
Expand All @@ -36,11 +35,6 @@ import (
"go.uber.org/mock/mockgen/model"
)

var (
imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
)

// sourceMode generates mocks via source file.
func sourceMode(source string) (*model.Package, error) {
srcDir, err := filepath.Abs(filepath.Dir(source))
Expand Down

0 comments on commit 2417c65

Please sign in to comment.