go/loader: fix interaction of cwd and vendor

Packages specified on the command line should be interpreted relative to
cwd iff they are local (e.g. ./http within $GOROOT/src/net); otherwise a
request for, say, "golang.org/x/net/http2/hpack" might return the vendored
package, depending on the working directory.

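The FindPackage hook function now takes a build.ImportMode parameter,
and its arguments are ordered (importPath, fromDir, mode) so that it
matches the signature of (*build.Context).Import.  Note that importPath
and fromDir are swapped relative to the old hook, so custom hooks must
be updated.  The AllowVendor flag is enabled only for imports within
source files, not for the initial packages.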

+ test.

Change-Id: I756dc46b70928d2fd9f824e6670092d8169e0d64
Reviewed-on: https://go-review.googlesource.com/18318
Reviewed-by: Robert Griesemer <gri@golang.org>
diff --git a/go/loader/loader.go b/go/loader/loader.go
index f7bbd3a..7780fad 100644
--- a/go/loader/loader.go
+++ b/go/loader/loader.go
@@ -103,7 +103,7 @@
 	// "go build" layout conventions, for example.
 	//
 	// It must be safe to call concurrently from multiple goroutines.
-	FindPackage func(ctxt *build.Context, fromDir, importPath string) (*build.Package, error)
+	FindPackage func(ctxt *build.Context, importPath, fromDir string, mode build.ImportMode) (*build.Package, error)
 }
 
 // A PkgSpec specifies a non-importable package to be created by Load.
@@ -378,8 +378,8 @@
 	prog   *Program   // the resulting program
 
 	// findpkg is a memoization of FindPackage.
-	findpkgMu sync.Mutex                 // guards findpkg
-	findpkg   map[[2]string]findpkgValue // key is (fromDir, importPath)
+	findpkgMu sync.Mutex // guards findpkg
+	findpkg   map[findpkgKey]findpkgValue
 
 	importedMu sync.Mutex             // guards imported
 	imported   map[string]*importInfo // all imported packages (incl. failures) by import path
@@ -393,6 +393,12 @@
 	graph   map[string]map[string]bool
 }
 
+type findpkgKey struct {
+	importPath string           // as requested, e.g. "./http" or "hpack"
+	fromDir    string           // conf.Cwd for command-line packages, else the importing package's dir
+	mode       build.ImportMode // 0 (no vendor lookup) for command-line packages; buildutil.AllowVendor for source imports
+}
+
 type findpkgValue struct {
 	bp  *build.Package
 	err error
@@ -461,9 +467,9 @@
 
 	// Install default FindPackage hook using go/build logic.
 	if conf.FindPackage == nil {
-		conf.FindPackage = func(ctxt *build.Context, fromDir, path string) (*build.Package, error) {
+		conf.FindPackage = func(ctxt *build.Context, path, fromDir string, mode build.ImportMode) (*build.Package, error) {
 			ioLimit <- true
-			bp, err := ctxt.Import(path, fromDir, buildutil.AllowVendor)
+			bp, err := ctxt.Import(path, fromDir, mode)
 			<-ioLimit
 			if _, ok := err.(*build.NoGoError); ok {
 				return bp, nil // empty directory is not an error
@@ -482,7 +488,7 @@
 	imp := importer{
 		conf:     conf,
 		prog:     prog,
-		findpkg:  make(map[[2]string]findpkgValue),
+		findpkg:  make(map[findpkgKey]findpkgValue),
 		imported: make(map[string]*importInfo),
 		start:    time.Now(),
 		graph:    make(map[string]map[string]bool),
@@ -494,7 +500,8 @@
 
 	// Load the initially imported packages and their dependencies,
 	// in parallel.
-	infos, importErrors := imp.importAll("", conf.Cwd, conf.ImportPkgs)
+	// No vendor check on packages imported from the command line.
+	infos, importErrors := imp.importAll("", conf.Cwd, conf.ImportPkgs, 0)
 	for _, ie := range importErrors {
 		conf.TypeChecker.Error(ie.err) // failed to create package
 		errpkgs = append(errpkgs, ie.path)
@@ -511,7 +518,8 @@
 			continue
 		}
 
-		bp, err := imp.findPackage(conf.Cwd, importPath)
+		// No vendor check on packages imported from command line.
+		bp, err := imp.findPackage(importPath, conf.Cwd, 0)
 		if err != nil {
 			// Package not found, or can't even parse package declaration.
 			// Already reported by previous loop; ignore it.
@@ -737,8 +745,6 @@
 // doImport imports the package denoted by path.
 // It implements the types.Importer signature.
 //
-// imports is the type-checker's package canonicalization map.
-//
 // It returns an error if a package could not be created
 // (e.g. go/build or parse error), but type errors are reported via
 // the types.Config.Error callback (the first of which is also saved
@@ -760,7 +766,7 @@
 			from.Pkg.Path())
 	}
 
-	bp, err := imp.findPackage(from.dir, to)
+	bp, err := imp.findPackage(to, from.dir, buildutil.AllowVendor)
 	if err != nil {
 		return nil, err
 	}
@@ -789,15 +795,15 @@
 
 // findPackage locates the package denoted by the importPath in the
 // specified directory.
-func (imp *importer) findPackage(fromDir, importPath string) (*build.Package, error) {
+func (imp *importer) findPackage(importPath, fromDir string, mode build.ImportMode) (*build.Package, error) {
 	// TODO(adonovan): opt: non-blocking duplicate-suppressing cache.
 	// i.e. don't hold the lock around FindPackage.
-	key := [2]string{fromDir, importPath}
+	key := findpkgKey{importPath, fromDir, mode}
 	imp.findpkgMu.Lock()
 	defer imp.findpkgMu.Unlock()
 	v, ok := imp.findpkg[key]
 	if !ok {
-		bp, err := imp.conf.FindPackage(imp.conf.build(), fromDir, importPath)
+		bp, err := imp.conf.FindPackage(imp.conf.build(), importPath, fromDir, mode)
 		v = findpkgValue{bp, err}
 		imp.findpkg[key] = v
 	}
@@ -813,12 +819,12 @@
 // fromDir is the directory containing the import declaration that
 // caused these imports.
 //
-func (imp *importer) importAll(fromPath, fromDir string, imports map[string]bool) (infos []*PackageInfo, errors []importError) {
+func (imp *importer) importAll(fromPath, fromDir string, imports map[string]bool, mode build.ImportMode) (infos []*PackageInfo, errors []importError) {
 	// TODO(adonovan): opt: do the loop in parallel once
 	// findPackage is non-blocking.
 	var pending []*importInfo
 	for importPath := range imports {
-		bp, err := imp.findPackage(fromDir, importPath)
+		bp, err := imp.findPackage(importPath, fromDir, mode)
 		if err != nil {
 			errors = append(errors, importError{
 				path: importPath,
@@ -958,7 +964,7 @@
 	}
 	// TODO(adonovan): opt: make the caller do scanImports.
 	// Callers with a build.Package can skip it.
-	imp.importAll(fromPath, info.dir, scanImports(files))
+	imp.importAll(fromPath, info.dir, scanImports(files), buildutil.AllowVendor)
 
 	if trace {
 		fmt.Fprintf(os.Stderr, "%s: start %q (%d)\n",
diff --git a/go/loader/loader_test.go b/go/loader/loader_test.go
index 3dbbfc3..bf0f5bb 100644
--- a/go/loader/loader_test.go
+++ b/go/loader/loader_test.go
@@ -431,6 +431,45 @@
 	}
 }
 
+func TestVendorCwd(t *testing.T) {
+	// Only local args (./x) resolve against cwd; vendoring applies only to imports in source files.
+	ctxt := fakeContext(map[string]string{
+		"net":          ``, // mkdir net
+		"net/http":     `package http; import _ "hpack"`,
+		"vendor":       ``, // mkdir vendor
+		"vendor/hpack": `package vendorhpack`,
+		"hpack":        `package hpack`,
+	})
+	for i, test := range []struct {
+		cwd, arg, want string
+	}{
+		{cwd: "/go/src/net", arg: "http"}, // not found: non-local args are not resolved against cwd
+		{cwd: "/go/src/net", arg: "./http", want: "net/http vendor/hpack"},
+		{cwd: "/go/src/net", arg: "hpack", want: "hpack"},
+		{cwd: "/go/src/vendor", arg: "hpack", want: "hpack"},
+		{cwd: "/go/src/vendor", arg: "./hpack", want: "vendor/hpack"},
+	} {
+		conf := loader.Config{
+			Cwd:   test.cwd,
+			Build: ctxt,
+		}
+		conf.Import(test.arg)
+
+		var got string
+		prog, err := conf.Load()
+		if prog != nil {
+			got = strings.Join(all(prog), " ")
+		}
+		if got != test.want {
+			t.Errorf("#%d: Load(%q) from %q: got %q, want %q",
+				i, test.arg, test.cwd, got, test.want)
+			if err != nil {
+				t.Errorf("Load failed: %v", err)
+			}
+		}
+	}
+}
+
 // TODO(adonovan): more Load tests:
 //
 // failures: