imports: remove globals, stop using build.Default

The imports package's public API is build.Default, but that doesn't mean
we need to use it in the internal implementation or the tests. Now we
have a new type, fixEnv, that contains everything relevant from
build.Context, as well as the various global vars that were only used
for testing.

Don't worry too much about the new function parameters; they mostly
move into the resolvers in the next CL.

Refactoring only; no user-visible changes intended.

Change-Id: I0d4c904955c5854dcdf904009cb3413c734baf88
Reviewed-on: https://go-review.googlesource.com/c/158437
Run-TryBot: Heschi Kreinick <heschi@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/imports/fix.go b/imports/fix.go
index 75c3b75..085d8aa 100644
--- a/imports/fix.go
+++ b/imports/fix.go
@@ -233,7 +233,7 @@
 	fset                 *token.FileSet // fset used to parse f and its siblings.
 	f                    *ast.File      // the file being fixed.
 	srcDir               string         // the directory containing f.
-	useGoPackages        bool           // use go/packages to load package information.
+	fixEnv               *fixEnv        // the environment to use for go commands, etc.
 	loadRealPackageNames bool           // if true, load package names from disk rather than guessing them.
 	otherFiles           []*ast.File    // sibling files.
 
@@ -258,9 +258,9 @@
 		unknown = append(unknown, imp.importPath)
 	}
 
-	if !p.useGoPackages {
+	if !p.fixEnv.shouldUseGoPackages() {
 		for _, path := range unknown {
-			name := importPathToName(path, p.srcDir)
+			name := importPathToName(p.fixEnv, path, p.srcDir)
 			if name == "" {
 				continue
 			}
@@ -272,7 +272,7 @@
 		return nil
 	}
 
-	cfg := newPackagesConfig(packages.LoadFiles)
+	cfg := p.fixEnv.newPackagesConfig(packages.LoadFiles)
 	pkgs, err := packages.Load(cfg, unknown...)
 	if err != nil {
 		return err
@@ -328,7 +328,9 @@
 	// f's imports by the identifier they introduce.
 	imports := collectImports(p.f)
 	if p.loadRealPackageNames {
-		p.loadPackageNames(append(imports, p.candidates...))
+		if err := p.loadPackageNames(append(imports, p.candidates...)); err != nil {
+			panic(err)
+		}
 	}
 	for _, imp := range imports {
 		p.existingImports[p.importIdentifier(imp)] = imp
@@ -448,7 +450,7 @@
 // easily be extended by adding a file with an init function.
 var fixImports = fixImportsDefault
 
-func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string) error {
+func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *fixEnv) error {
 	abs, err := filepath.Abs(filename)
 	if err != nil {
 		return err
@@ -462,7 +464,7 @@
 	// derive package names from import paths, see if the file is already
 	// complete. We can't add any imports yet, because we don't know
 	// if missing references are actually package vars.
-	p := &pass{fset: fset, f: f, srcDir: srcDir}
+	p := &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env}
 	if p.load() {
 		return nil
 	}
@@ -471,7 +473,7 @@
 
 	// Second pass: add information from other files in the same package,
 	// like their package vars and imports.
-	p = &pass{fset: fset, f: f, srcDir: srcDir}
+	p = &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env}
 	p.otherFiles = otherFiles
 	if p.load() {
 		return nil
@@ -484,13 +486,9 @@
 		return nil
 	}
 
-	// The only things that use go/packages happen in the third pass,
-	// so we can delay calling go env until this point.
-	useGoPackages := shouldUseGoPackages()
-
 	// Third pass: get real package names where we had previously used
 	// the naive algorithm.
-	p = &pass{fset: fset, f: f, srcDir: srcDir, useGoPackages: useGoPackages}
+	p = &pass{fset: fset, f: f, srcDir: srcDir, fixEnv: env}
 	p.loadRealPackageNames = true
 	p.otherFiles = otherFiles
 	if p.load() {
@@ -514,35 +512,66 @@
 	return nil
 }
 
-// Values controlling the use of go/packages, for testing only.
-var forceGoPackages, _ = strconv.ParseBool(os.Getenv("GOIMPORTSFORCEGOPACKAGES"))
-var goPackagesDir string
-var go111ModuleEnv string
+// fixEnv contains environment variables and settings that affect the use of
+// the go command, the go/build package, etc.
+type fixEnv struct {
+	// If non-empty, these will be used instead of the
+	// process-wide values.
+	GOPATH, GOROOT, GO111MODULE string
+	WorkingDir                  string
 
-func shouldUseGoPackages() bool {
-	if forceGoPackages {
+	// If true, use go/packages regardless of the environment.
+	ForceGoPackages bool
+
+	ranGoEnv bool
+	gomod    string
+}
+
+func (e *fixEnv) env() []string {
+	env := os.Environ()
+	add := func(k, v string) {
+		if v != "" {
+			env = append(env, k+"="+v)
+		}
+	}
+	add("GOPATH", e.GOPATH)
+	add("GOROOT", e.GOROOT)
+	add("GO111MODULE", e.GO111MODULE)
+	return env
+}
+
+func (e *fixEnv) shouldUseGoPackages() bool {
+	if e.ForceGoPackages {
 		return true
 	}
 
-	cmd := exec.Command("go", "env", "GOMOD")
-	cmd.Dir = goPackagesDir
-	out, err := cmd.Output()
-	if err != nil {
-		return false
+	if !e.ranGoEnv {
+		e.ranGoEnv = true
+		cmd := exec.Command("go", "env", "GOMOD")
+		cmd.Dir = e.WorkingDir
+		cmd.Env = e.env()
+		out, err := cmd.Output()
+		if err != nil {
+			return false
+		}
+		e.gomod = string(bytes.TrimSpace(out))
 	}
-	return len(bytes.TrimSpace(out)) > 0
+	return e.gomod != ""
 }
 
-func newPackagesConfig(mode packages.LoadMode) *packages.Config {
-	cfg := &packages.Config{
+func (e *fixEnv) newPackagesConfig(mode packages.LoadMode) *packages.Config {
+	return &packages.Config{
 		Mode: mode,
-		Dir:  goPackagesDir,
-		Env:  append(os.Environ(), "GOROOT="+build.Default.GOROOT, "GOPATH="+build.Default.GOPATH),
+		Dir:  e.WorkingDir,
+		Env:  e.env(),
 	}
-	if go111ModuleEnv != "" {
-		cfg.Env = append(cfg.Env, "GO111MODULE="+go111ModuleEnv)
-	}
-	return cfg
+}
+
+func (e *fixEnv) buildContext() *build.Context {
+	ctx := build.Default
+	ctx.GOROOT = e.GOROOT
+	ctx.GOPATH = e.GOPATH
+	return &ctx
 }
 
 func addStdlibCandidates(pass *pass, refs map[string]map[string]bool) {
@@ -566,13 +595,13 @@
 	}
 }
 
-func scanGoPackages(refs map[string]map[string]bool) ([]*pkg, error) {
+func scanGoPackages(env *fixEnv, refs map[string]map[string]bool) ([]*pkg, error) {
 	var loadQueries []string
 	for pkgName := range refs {
 		loadQueries = append(loadQueries, "name="+pkgName)
 	}
 	sort.Strings(loadQueries)
-	cfg := newPackagesConfig(packages.LoadFiles)
+	cfg := env.newPackagesConfig(packages.LoadFiles)
 	goPackages, err := packages.Load(cfg, loadQueries...)
 	if err != nil {
 		return nil, err
@@ -593,14 +622,14 @@
 
 func addExternalCandidatesDefault(pass *pass, refs map[string]map[string]bool, filename string) error {
 	var dirScan []*pkg
-	if pass.useGoPackages {
+	if pass.fixEnv.shouldUseGoPackages() {
 		var err error
-		dirScan, err = scanGoPackages(refs)
+		dirScan, err = scanGoPackages(pass.fixEnv, refs)
 		if err != nil {
 			return err
 		}
 	} else {
-		dirScan = scanGoDirs()
+		dirScan = scanGoDirs(pass.fixEnv)
 	}
 
 	// Search for imports matching potential package references.
@@ -625,7 +654,7 @@
 		go func(pkgName string, symbols map[string]bool) {
 			defer wg.Done()
 
-			found, err := findImport(ctx, dirScan, pkgName, symbols, filename)
+			found, err := findImport(ctx, pass.fixEnv, dirScan, pkgName, symbols, filename)
 
 			if err != nil {
 				firstErrOnce.Do(func() {
@@ -678,13 +707,13 @@
 
 // importPathToNameGoPath finds out the actual package name, as declared in its .go files.
 // If there's a problem, it returns "".
-func importPathToName(importPath, srcDir string) (packageName string) {
+func importPathToName(env *fixEnv, importPath, srcDir string) (packageName string) {
 	// Fast path for standard library without going to disk.
 	if _, ok := stdlib[importPath]; ok {
 		return path.Base(importPath) // stdlib packages always match their paths.
 	}
 
-	pkgName, err := importPathToNameGoPathParse(importPath, srcDir)
+	pkgName, err := importPathToNameGoPathParse(env, importPath, srcDir)
 	if Debug {
 		log.Printf("importPathToNameGoPathParse(%q, srcDir=%q) = %q, %v", importPath, srcDir, pkgName, err)
 	}
@@ -698,8 +727,8 @@
 // the only thing desired is the package name. It uses build.FindOnly
 // to find the directory and then only parses one file in the package,
 // trusting that the files in the directory are consistent.
-func importPathToNameGoPathParse(importPath, srcDir string) (packageName string, err error) {
-	buildPkg, err := build.Import(importPath, srcDir, build.FindOnly)
+func importPathToNameGoPathParse(env *fixEnv, importPath, srcDir string) (packageName string, err error) {
+	buildPkg, err := env.buildContext().Import(importPath, srcDir, build.FindOnly)
 	if err != nil {
 		return "", err
 	}
@@ -798,7 +827,7 @@
 }
 
 // scanGoDirs populates the dirScan map for GOPATH and GOROOT.
-func scanGoDirs() []*pkg {
+func scanGoDirs(env *fixEnv) []*pkg {
 	dupCheck := make(map[string]bool)
 	var result []*pkg
 
@@ -818,7 +847,7 @@
 			dir:             dir,
 		})
 	}
-	gopathwalk.Walk(gopathwalk.SrcDirsRoots(), add, gopathwalk.Options{Debug: Debug, ModulesEnabled: false})
+	gopathwalk.Walk(gopathwalk.SrcDirsRoots(env.buildContext()), add, gopathwalk.Options{Debug: Debug, ModulesEnabled: false})
 	return result
 }
 
@@ -837,7 +866,7 @@
 
 // loadExports returns the set of exported symbols in the package at dir.
 // It returns nil on error or if the package name in dir does not match expectPackage.
-func loadExports(ctx context.Context, expectPackage string, pkg *pkg) (map[string]bool, error) {
+func loadExports(ctx context.Context, env *fixEnv, expectPackage string, pkg *pkg) (map[string]bool, error) {
 	if Debug {
 		log.Printf("loading exports in dir %s (seeking package %s)", pkg.dir, expectPackage)
 	}
@@ -871,7 +900,7 @@
 		if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") {
 			continue
 		}
-		match, err := build.Default.MatchFile(pkg.dir, fi.Name())
+		match, err := env.buildContext().MatchFile(pkg.dir, fi.Name())
 		if err != nil || !match {
 			continue
 		}
@@ -924,7 +953,7 @@
 
 // findImport searches for a package with the given symbols.
 // If no package is found, findImport returns ("", false, nil)
-func findImport(ctx context.Context, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) {
+func findImport(ctx context.Context, env *fixEnv, dirScan []*pkg, pkgName string, symbols map[string]bool, filename string) (*pkg, error) {
 	pkgDir, err := filepath.Abs(filename)
 	if err != nil {
 		return nil, err
@@ -986,7 +1015,7 @@
 					wg.Done()
 				}()
 
-				exports, err := loadExports(ctx, pkgName, c.pkg)
+				exports, err := loadExports(ctx, env, pkgName, c.pkg)
 				if err != nil {
 					if Debug {
 						log.Printf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err)
diff --git a/imports/fix_test.go b/imports/fix_test.go
index 1006e62..dd9fe45 100644
--- a/imports/fix_test.go
+++ b/imports/fix_test.go
@@ -6,7 +6,6 @@
 
 import (
 	"fmt"
-	"go/build"
 	"path/filepath"
 	"runtime"
 	"strings"
@@ -1520,6 +1519,7 @@
 		t.Run(kind, func(t *testing.T) {
 			t.Helper()
 
+			forceGoPackages := false
 			var exporter packagestest.Exporter
 			switch kind {
 			case "GOPATH":
@@ -1545,30 +1545,15 @@
 				env[k] = v
 			}
 
-			goroot := env["GOROOT"]
-			gopath := env["GOPATH"]
-
-			oldGOPATH := build.Default.GOPATH
-			oldGOROOT := build.Default.GOROOT
-			oldCompiler := build.Default.Compiler
-			build.Default.GOROOT = goroot
-			build.Default.GOPATH = gopath
-			build.Default.Compiler = "gc"
-			goPackagesDir = exported.Config.Dir
-			go111ModuleEnv = env["GO111MODULE"]
-
-			defer func() {
-				build.Default.GOPATH = oldGOPATH
-				build.Default.GOROOT = oldGOROOT
-				build.Default.Compiler = oldCompiler
-				go111ModuleEnv = ""
-				goPackagesDir = ""
-				forceGoPackages = false
-			}()
-
 			it := &goimportTest{
-				T:        t,
-				gopath:   gopath,
+				T: t,
+				fixEnv: &fixEnv{
+					GOROOT:          env["GOROOT"],
+					GOPATH:          env["GOPATH"],
+					GO111MODULE:     env["GO111MODULE"],
+					WorkingDir:      exported.Config.Dir,
+					ForceGoPackages: forceGoPackages,
+				},
 				exported: exported,
 			}
 			fn(it)
@@ -1586,7 +1571,7 @@
 
 type goimportTest struct {
 	*testing.T
-	gopath   string
+	fixEnv   *fixEnv
 	exported *packagestest.Exported
 }
 
@@ -1596,7 +1581,7 @@
 	if f == "" {
 		t.Fatalf("%v not found in exported files (typo in filename?)", file)
 	}
-	buf, err := Process(f, contents, opts)
+	buf, err := process(f, contents, opts, t.fixEnv)
 	if err != nil {
 		t.Fatalf("Process() = %v", err)
 	}
@@ -1818,7 +1803,7 @@
 			},
 		},
 	}.test(t, func(t *goimportTest) {
-		got, err := importPathToNameGoPathParse("example.net/pkg", filepath.Join(t.gopath, "src", "other.net"))
+		got, err := importPathToNameGoPathParse(t.fixEnv, "example.net/pkg", filepath.Join(t.fixEnv.GOPATH, "src", "other.net"))
 		if err != nil {
 			t.Fatal(err)
 		}
diff --git a/imports/imports.go b/imports/imports.go
index 717a6f3..07101cb 100644
--- a/imports/imports.go
+++ b/imports/imports.go
@@ -13,6 +13,7 @@
 	"bytes"
 	"fmt"
 	"go/ast"
+	"go/build"
 	"go/format"
 	"go/parser"
 	"go/printer"
@@ -45,6 +46,11 @@
 // so it is important that filename be accurate.
 // To process data ``as if'' it were in filename, pass the data as a non-nil src.
 func Process(filename string, src []byte, opt *Options) ([]byte, error) {
+	env := &fixEnv{GOPATH: build.Default.GOPATH, GOROOT: build.Default.GOROOT}
+	return process(filename, src, opt, env)
+}
+
+func process(filename string, src []byte, opt *Options, env *fixEnv) ([]byte, error) {
 	if opt == nil {
 		opt = &Options{Comments: true, TabIndent: true, TabWidth: 8}
 	}
@@ -63,7 +69,7 @@
 	}
 
 	if !opt.FormatOnly {
-		if err := fixImports(fileSet, file, filename); err != nil {
+		if err := fixImports(fileSet, file, filename, env); err != nil {
 			return nil, err
 		}
 	}
diff --git a/internal/gopathwalk/walk.go b/internal/gopathwalk/walk.go
index a561f9f..488088b 100644
--- a/internal/gopathwalk/walk.go
+++ b/internal/gopathwalk/walk.go
@@ -44,10 +44,10 @@
 }
 
 // SrcDirsRoots returns the roots from build.Default.SrcDirs(). Not modules-compatible.
-func SrcDirsRoots() []Root {
+func SrcDirsRoots(ctx *build.Context) []Root {
 	var roots []Root
-	roots = append(roots, Root{filepath.Join(build.Default.GOROOT, "src"), RootGOROOT})
-	for _, p := range filepath.SplitList(build.Default.GOPATH) {
+	roots = append(roots, Root{filepath.Join(ctx.GOROOT, "src"), RootGOROOT})
+	for _, p := range filepath.SplitList(ctx.GOPATH) {
 		roots = append(roots, Root{filepath.Join(p, "src"), RootGOPATH})
 	}
 	return roots