imports: add support for vendor directories
For golang/go#12278 (original request).
Change-Id: I27b136041f54edcde4bf474215b48ebb0417f34d
diff --git a/imports/fix.go b/imports/fix.go
index 3ccee0e..a201f9e 100644
--- a/imports/fix.go
+++ b/imports/fix.go
@@ -45,7 +45,7 @@
return 0
}
-func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
+func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []string, err error) {
// refs are a set of possible package references currently unsatisfied by imports.
// first key: either base package (e.g. "fmt") or renamed package
// second key: referenced package symbol (e.g. "Println")
@@ -117,7 +117,7 @@
continue // skip over packages already imported
}
go func(pkgName string, symbols map[string]bool) {
- ipath, rename, err := findImport(pkgName, symbols)
+ ipath, rename, err := findImport(pkgName, symbols, filename)
r := result{ipath: ipath, err: err}
if rename {
r.name = pkgName
@@ -304,7 +304,7 @@
// extended by adding a file with an init function.
var findImport = findImportGoPath
-func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) {
+func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (string, bool, error) {
// Fast path for the standard library.
// In the common case we hopefully never have to scan the GOPATH, which can
// be slow with moving disks.
@@ -320,51 +320,79 @@
pkgIndexOnce.Do(loadPkgIndex)
// Collect exports for packages with matching names.
- var wg sync.WaitGroup
- var pkgsMu sync.Mutex // guards pkgs
- // full importpath => exported symbol => True
- // e.g. "net/http" => "Client" => True
- pkgs := make(map[string]map[string]bool)
+ var (
+ wg sync.WaitGroup
+ mu sync.Mutex
+ shortest string
+ )
pkgIndex.Lock()
for _, pkg := range pkgIndex.m[pkgName] {
+ if !canUse(filename, pkg.dir) {
+ continue
+ }
wg.Add(1)
go func(importpath, dir string) {
defer wg.Done()
exports := loadExports(dir)
- if exports != nil {
- pkgsMu.Lock()
- pkgs[importpath] = exports
- pkgsMu.Unlock()
+ if exports == nil {
+ return
}
+ // If it doesn't have the right symbols, stop.
+ for symbol := range symbols {
+ if !exports[symbol] {
+ return
+ }
+ }
+
+ // Devendorize for use in import statement.
+ if i := strings.LastIndex(importpath, "/vendor/"); i >= 0 {
+ importpath = importpath[i+len("/vendor/"):]
+ } else if strings.HasPrefix(importpath, "vendor/") {
+ importpath = importpath[len("vendor/"):]
+ }
+
+ // Save as the answer.
+ // If there are multiple candidates, the shortest wins,
+ // to prefer "bytes" over "github.com/foo/bytes".
+ mu.Lock()
+ if shortest == "" || len(importpath) < len(shortest) || len(importpath) == len(shortest) && importpath < shortest {
+ shortest = importpath
+ }
+ mu.Unlock()
}(pkg.importpath, pkg.dir)
}
pkgIndex.Unlock()
wg.Wait()
- // Filter out packages missing required exported symbols.
- for symbol := range symbols {
- for importpath, exports := range pkgs {
- if !exports[symbol] {
- delete(pkgs, importpath)
- }
- }
- }
- if len(pkgs) == 0 {
- return "", false, nil
- }
-
- // If there are multiple candidate packages, the shortest one wins.
- // This is a heuristic to prefer the standard library (e.g. "bytes")
- // over e.g. "github.com/foo/bar/bytes".
- shortest := ""
- for importPath := range pkgs {
- if shortest == "" || len(importPath) < len(shortest) {
- shortest = importPath
- }
- }
return shortest, false, nil
}
+func canUse(filename, dir string) bool {
+ dirSlash := filepath.ToSlash(dir)
+ if !strings.Contains(dirSlash, "/vendor/") && !strings.Contains(dirSlash, "/internal/") && !strings.HasSuffix(dirSlash, "/internal") {
+ return true
+ }
+ // Vendor or internal directory only visible from children of parent.
+ // That means the path from the current directory to the target directory
+ // can contain ../vendor or ../internal but not ../foo/vendor or ../foo/internal
+ // or bar/vendor or bar/internal.
+ // After stripping all the leading ../, the only okay place to see vendor or internal
+ // is at the very beginning of the path.
+ abs, err := filepath.Abs(filename)
+ if err != nil {
+ return false
+ }
+ rel, err := filepath.Rel(abs, dir)
+ if err != nil {
+ return false
+ }
+ relSlash := filepath.ToSlash(rel)
+ if i := strings.LastIndex(relSlash, "../"); i >= 0 {
+ relSlash = relSlash[i+len("../"):]
+ }
+ return !strings.Contains(relSlash, "/vendor/") && !strings.Contains(relSlash, "/internal/") && !strings.HasSuffix(relSlash, "/internal")
+}
+
type visitFn func(node ast.Node) ast.Visitor
func (fn visitFn) Visit(node ast.Node) ast.Visitor {
diff --git a/imports/fix_test.go b/imports/fix_test.go
index f087bc7..3a5d7c2 100644
--- a/imports/fix_test.go
+++ b/imports/fix_test.go
@@ -10,6 +10,7 @@
"io/ioutil"
"os"
"path/filepath"
+ "runtime"
"sync"
"testing"
)
@@ -743,7 +744,7 @@
"user": "appengine/user",
"zip": "archive/zip",
}
- findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) {
+ findImport = func(pkgName string, symbols map[string]bool, filename string) (string, bool, error) {
return simplePkgs[pkgName], pkgName == "str", nil
}
@@ -813,7 +814,7 @@
build.Default.GOPATH = oldGOPATH
}()
- got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
+ got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go")
if err != nil {
t.Fatal(err)
}
@@ -821,7 +822,7 @@
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
}
- got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
+ got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go")
if err != nil {
t.Fatal(err)
}
@@ -830,6 +831,68 @@
}
}
+func TestFindImportInternal(t *testing.T) {
+ pkgIndexOnce = sync.Once{}
+ oldGOPATH := build.Default.GOPATH
+ build.Default.GOPATH = ""
+ defer func() {
+ build.Default.GOPATH = oldGOPATH
+ }()
+
+ _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/internal"))
+ if err != nil {
+ t.Skip(err)
+ }
+
+ got, rename, err := findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != "internal/race" || rename {
+ t.Errorf(`findImportGoPath("race", Acquire ...)=%q, %t, want "internal/race", false`, got, rename)
+ }
+
+ // should not be able to use internal from outside that tree
+ got, rename, err = findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != "" || rename {
+ t.Errorf(`findImportGoPath("race", Acquire ...)=%q, %t, want "", false`, got, rename)
+ }
+}
+
+func TestFindImportVendor(t *testing.T) {
+ pkgIndexOnce = sync.Once{}
+ oldGOPATH := build.Default.GOPATH
+ build.Default.GOPATH = ""
+ defer func() {
+ build.Default.GOPATH = oldGOPATH
+ }()
+
+ _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor"))
+ if err != nil {
+ t.Skip(err)
+ }
+
+ got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != "golang.org/x/net/http2/hpack" || rename {
+ t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want "golang.org/x/net/http2/hpack", false`, got, rename)
+ }
+
+ // should not be able to use vendor from outside that tree
+ got, rename, err = findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "x.go"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != "" || rename {
+ t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want "", false`, got, rename)
+ }
+}
+
func TestFindImportStdlib(t *testing.T) {
tests := []struct {
pkg string
diff --git a/imports/imports.go b/imports/imports.go
index e30946b..fee0789 100644
--- a/imports/imports.go
+++ b/imports/imports.go
@@ -46,7 +46,7 @@
return nil, err
}
- _, err = fixImports(fileSet, file)
+ _, err = fixImports(fileSet, file, filename)
if err != nil {
return nil, err
}