imports: add support for vendor directories

For golang/go#12278 (original request).

Change-Id: I27b136041f54edcde4bf474215b48ebb0417f34d
diff --git a/imports/fix.go b/imports/fix.go
index 3ccee0e..a201f9e 100644
--- a/imports/fix.go
+++ b/imports/fix.go
@@ -45,7 +45,7 @@
 	return 0
 }
 
-func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
+func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []string, err error) {
 	// refs are a set of possible package references currently unsatisfied by imports.
 	// first key: either base package (e.g. "fmt") or renamed package
 	// second key: referenced package symbol (e.g. "Println")
@@ -117,7 +117,7 @@
 			continue // skip over packages already imported
 		}
 		go func(pkgName string, symbols map[string]bool) {
-			ipath, rename, err := findImport(pkgName, symbols)
+			ipath, rename, err := findImport(pkgName, symbols, filename)
 			r := result{ipath: ipath, err: err}
 			if rename {
 				r.name = pkgName
@@ -304,7 +304,7 @@
 // extended by adding a file with an init function.
 var findImport = findImportGoPath
 
-func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) {
+func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (string, bool, error) {
 	// Fast path for the standard library.
 	// In the common case we hopefully never have to scan the GOPATH, which can
 	// be slow with moving disks.
@@ -320,51 +320,79 @@
 	pkgIndexOnce.Do(loadPkgIndex)
 
 	// Collect exports for packages with matching names.
-	var wg sync.WaitGroup
-	var pkgsMu sync.Mutex // guards pkgs
-	// full importpath => exported symbol => True
-	// e.g. "net/http" => "Client" => True
-	pkgs := make(map[string]map[string]bool)
+	var (
+		wg       sync.WaitGroup
+		mu       sync.Mutex
+		shortest string
+	)
 	pkgIndex.Lock()
 	for _, pkg := range pkgIndex.m[pkgName] {
+		if !canUse(filename, pkg.dir) {
+			continue
+		}
 		wg.Add(1)
 		go func(importpath, dir string) {
 			defer wg.Done()
 			exports := loadExports(dir)
-			if exports != nil {
-				pkgsMu.Lock()
-				pkgs[importpath] = exports
-				pkgsMu.Unlock()
+			if exports == nil {
+				return
 			}
+			// If it doesn't have the right symbols, stop.
+			for symbol := range symbols {
+				if !exports[symbol] {
+					return
+				}
+			}
+
+			// Devendorize for use in import statement.
+			if i := strings.LastIndex(importpath, "/vendor/"); i >= 0 {
+				importpath = importpath[i+len("/vendor/"):]
+			} else if strings.HasPrefix(importpath, "vendor/") {
+				importpath = importpath[len("vendor/"):]
+			}
+
+			// Save as the answer.
+			// If there are multiple candidates, the shortest wins,
+			// to prefer "bytes" over "github.com/foo/bytes".
+			mu.Lock()
+			if shortest == "" || len(importpath) < len(shortest) || len(importpath) == len(shortest) && importpath < shortest {
+				shortest = importpath
+			}
+			mu.Unlock()
 		}(pkg.importpath, pkg.dir)
 	}
 	pkgIndex.Unlock()
 	wg.Wait()
 
-	// Filter out packages missing required exported symbols.
-	for symbol := range symbols {
-		for importpath, exports := range pkgs {
-			if !exports[symbol] {
-				delete(pkgs, importpath)
-			}
-		}
-	}
-	if len(pkgs) == 0 {
-		return "", false, nil
-	}
-
-	// If there are multiple candidate packages, the shortest one wins.
-	// This is a heuristic to prefer the standard library (e.g. "bytes")
-	// over e.g. "github.com/foo/bar/bytes".
-	shortest := ""
-	for importPath := range pkgs {
-		if shortest == "" || len(importPath) < len(shortest) {
-			shortest = importPath
-		}
-	}
 	return shortest, false, nil
 }
 
+func canUse(filename, dir string) bool {
+	dirSlash := filepath.ToSlash(dir)
+	if !strings.Contains(dirSlash, "/vendor/") && !strings.Contains(dirSlash, "/internal/") && !strings.HasSuffix(dirSlash, "/internal") {
+		return true
+	}
+	// Vendor or internal directory only visible from children of parent.
+	// That means the path from the current directory to the target directory
+	// can contain ../vendor or ../internal but not ../foo/vendor or ../foo/internal
+	// or bar/vendor or bar/internal.
+	// After stripping all the leading ../, the only okay place to see vendor or internal
+	// is at the very beginning of the path.
+	abs, err := filepath.Abs(filename)
+	if err != nil {
+		return false
+	}
+	rel, err := filepath.Rel(abs, dir)
+	if err != nil {
+		return false
+	}
+	relSlash := filepath.ToSlash(rel)
+	if i := strings.LastIndex(relSlash, "../"); i >= 0 {
+		relSlash = relSlash[i+len("../"):]
+	}
+	return !strings.Contains(relSlash, "/vendor/") && !strings.Contains(relSlash, "/internal/") && !strings.HasSuffix(relSlash, "/internal")
+}
+
 type visitFn func(node ast.Node) ast.Visitor
 
 func (fn visitFn) Visit(node ast.Node) ast.Visitor {
diff --git a/imports/fix_test.go b/imports/fix_test.go
index f087bc7..3a5d7c2 100644
--- a/imports/fix_test.go
+++ b/imports/fix_test.go
@@ -10,6 +10,7 @@
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"runtime"
 	"sync"
 	"testing"
 )
@@ -743,7 +744,7 @@
 		"user":      "appengine/user",
 		"zip":       "archive/zip",
 	}
-	findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) {
+	findImport = func(pkgName string, symbols map[string]bool, filename string) (string, bool, error) {
 		return simplePkgs[pkgName], pkgName == "str", nil
 	}
 
@@ -813,7 +814,7 @@
 		build.Default.GOPATH = oldGOPATH
 	}()
 
-	got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
+	got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -821,7 +822,7 @@
 		t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
 	}
 
-	got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
+	got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -830,6 +831,68 @@
 	}
 }
 
+func TestFindImportInternal(t *testing.T) {
+	pkgIndexOnce = sync.Once{}
+	oldGOPATH := build.Default.GOPATH
+	build.Default.GOPATH = ""
+	defer func() {
+		build.Default.GOPATH = oldGOPATH
+	}()
+
+	_, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/internal"))
+	if err != nil {
+		t.Skip(err)
+	}
+
+	got, rename, err := findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "internal/race" || rename {
+		t.Errorf(`findImportGoPath("race", Acquire ...)=%q, %t, want "internal/race", false`, got, rename)
+	}
+
+	// should not be able to use internal from outside that tree
+	got, rename, err = findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "" || rename {
+		t.Errorf(`findImportGoPath("race", Acquire ...)=%q, %t, want "", false`, got, rename)
+	}
+}
+
+func TestFindImportVendor(t *testing.T) {
+	pkgIndexOnce = sync.Once{}
+	oldGOPATH := build.Default.GOPATH
+	build.Default.GOPATH = ""
+	defer func() {
+		build.Default.GOPATH = oldGOPATH
+	}()
+
+	_, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor"))
+	if err != nil {
+		t.Skip(err)
+	}
+
+	got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "golang.org/x/net/http2/hpack" || rename {
+		t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want "golang.org/x/net/http2/hpack", false`, got, rename)
+	}
+
+	// should not be able to use vendor from outside that tree
+	got, rename, err = findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "x.go"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "" || rename {
+		t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want "", false`, got, rename)
+	}
+}
+
 func TestFindImportStdlib(t *testing.T) {
 	tests := []struct {
 		pkg     string
diff --git a/imports/imports.go b/imports/imports.go
index e30946b..fee0789 100644
--- a/imports/imports.go
+++ b/imports/imports.go
@@ -46,7 +46,7 @@
 		return nil, err
 	}
 
-	_, err = fixImports(fileSet, file)
+	_, err = fixImports(fileSet, file, filename)
 	if err != nil {
 		return nil, err
 	}