internal/lsp: unify go command invocation logic

We have two parallel code paths for Load and other go command
invocations. Unify them by introducing a mode argument to the various
functions that run the go command, indicating the purpose of the
invocation. That purpose can be used to infer what features should be
enabled.

In the future, I hope we can use the mode to decide whether mod
file updates and network access should be allowed.

Change-Id: I49c67fcefc9141287b78c56e9812ee6a8ac8378a
Reviewed-on: https://go-review.googlesource.com/c/tools/+/265238
Trust: Heschi Kreinick <heschi@google.com>
Run-TryBot: Heschi Kreinick <heschi@google.com>
gopls-CI: kokoro <noreply+kokoro@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/internal/lsp/cache/load.go b/internal/lsp/cache/load.go
index 12a09b6..0d07070 100644
--- a/internal/lsp/cache/load.go
+++ b/internal/lsp/cache/load.go
@@ -16,6 +16,7 @@
 
 	"golang.org/x/tools/go/packages"
 	"golang.org/x/tools/internal/event"
+	"golang.org/x/tools/internal/gocommand"
 	"golang.org/x/tools/internal/lsp/debug/tag"
 	"golang.org/x/tools/internal/lsp/source"
 	"golang.org/x/tools/internal/packagesinternal"
@@ -91,70 +92,14 @@
 	defer done()
 
 	cleanup := func() {}
-	wdir := s.view.rootURI.Filename()
 
-	var modFile string
-	var modURI span.URI
-	var modContent []byte
-	switch {
-	case s.workspaceMode()&usesWorkspaceModule != 0:
-		var (
-			tmpDir span.URI
-			err    error
-		)
-		tmpDir, cleanup, err = s.tempWorkspaceModule(ctx)
-		if err != nil {
-			return err
-		}
-		wdir = tmpDir.Filename()
-		modURI = span.URIFromPath(filepath.Join(wdir, "go.mod"))
-		modContent, err = ioutil.ReadFile(modURI.Filename())
-		if err != nil {
-			return err
-		}
-	case s.workspaceMode()&tempModfile != 0:
-		// -modfile is unsupported when there are > 1 modules in the workspace.
-		if len(s.modules) != 1 {
-			panic(fmt.Sprintf("unsupported use of -modfile, expected 1 module, got %v", len(s.modules)))
-		}
-		var mod *moduleRoot
-		for _, m := range s.modules { // range to access the only element
-			mod = m
-		}
-		modURI = mod.modURI
-		modFH, err := s.GetFile(ctx, mod.modURI)
-		if err != nil {
-			return err
-		}
-		modContent, err = modFH.Read()
-		if err != nil {
-			return err
-		}
-		var sumFH source.FileHandle
-		if mod.sumURI != "" {
-			sumFH, err = s.GetFile(ctx, mod.sumURI)
-			if err != nil {
-				return err
-			}
-		}
-		var tmpURI span.URI
-		tmpURI, cleanup, err = tempModFile(modFH, sumFH)
-		if err != nil {
-			return err
-		}
-		modFile = tmpURI.Filename()
-	}
-
-	cfg := s.config(ctx, wdir)
-	packagesinternal.SetModFile(cfg, modFile)
-	modMod, err := s.needsModEqualsMod(ctx, modURI, modContent)
+	_, inv, cleanup, err := s.goCommandInvocation(ctx, source.ForTypeChecking, &gocommand.Invocation{
+		WorkingDir: s.view.rootURI.Filename(),
+	})
 	if err != nil {
 		return err
 	}
-	if modMod {
-		packagesinternal.SetModFlag(cfg, "mod")
-	}
-
+	cfg := s.config(ctx, inv)
 	pkgs, err := packages.Load(cfg, query...)
 	cleanup()
 
diff --git a/internal/lsp/cache/mod.go b/internal/lsp/cache/mod.go
index 736b56b..f95590f 100644
--- a/internal/lsp/cache/mod.go
+++ b/internal/lsp/cache/mod.go
@@ -238,7 +238,7 @@
 		for _, req := range pm.File.Require {
 			inv.Args = append(inv.Args, req.Mod.Path)
 		}
-		stdout, err := snapshot.RunGoCommandDirect(ctx, inv)
+		stdout, err := snapshot.RunGoCommandDirect(ctx, source.Normal, inv)
 		if err != nil {
 			return &modWhyData{err: err}
 		}
@@ -336,7 +336,7 @@
 			// (see golang/go#38711).
 			inv.ModFlag = "readonly"
 		}
-		stdout, err := snapshot.RunGoCommandDirect(ctx, inv)
+		stdout, err := snapshot.RunGoCommandDirect(ctx, source.Normal, inv)
 		if err != nil {
 			return &modUpgradeData{err: err}
 		}
diff --git a/internal/lsp/cache/mod_tidy.go b/internal/lsp/cache/mod_tidy.go
index 375fddc..9e92287 100644
--- a/internal/lsp/cache/mod_tidy.go
+++ b/internal/lsp/cache/mod_tidy.go
@@ -57,9 +57,6 @@
 	if fh.Kind() != source.Mod {
 		return nil, fmt.Errorf("%s is not a go.mod file", fh.URI())
 	}
-	if s.workspaceMode()&tempModfile == 0 {
-		return nil, source.ErrTmpModfileUnsupported
-	}
 	if handle := s.getModTidyHandle(fh.URI()); handle != nil {
 		return handle.tidy(ctx, s)
 	}
@@ -118,7 +115,7 @@
 			Args:       []string{"tidy"},
 			WorkingDir: filepath.Dir(fh.URI().Filename()),
 		}
-		tmpURI, inv, cleanup, err := snapshot.goCommandInvocation(ctx, true, inv)
+		tmpURI, inv, cleanup, err := snapshot.goCommandInvocation(ctx, source.WriteTemporaryModFile, inv)
 		if err != nil {
 			return &modTidyData{err: err}
 		}
diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go
index 34d78a1..27b8ba1 100644
--- a/internal/lsp/cache/snapshot.go
+++ b/internal/lsp/cache/snapshot.go
@@ -182,18 +182,16 @@
 // TODO(rstambler): go/packages requires that we do not provide overlays for
 // multiple modules in on config, so buildOverlay needs to filter overlays by
 // module.
-func (s *snapshot) config(ctx context.Context, dir string) *packages.Config {
+func (s *snapshot) config(ctx context.Context, inv *gocommand.Invocation) *packages.Config {
 	s.view.optionsMu.Lock()
-	env := s.view.options.EnvSlice()
-	buildFlags := append([]string{}, s.view.options.BuildFlags...)
 	verboseOutput := s.view.options.VerboseOutput
 	s.view.optionsMu.Unlock()
 
 	cfg := &packages.Config{
 		Context:    ctx,
-		Dir:        dir,
-		Env:        append(append([]string{}, env...), "GO111MODULE="+s.view.go111module),
-		BuildFlags: buildFlags,
+		Dir:        inv.WorkingDir,
+		Env:        inv.Env,
+		BuildFlags: inv.BuildFlags,
 		Mode: packages.NeedName |
 			packages.NeedFiles |
 			packages.NeedCompiledGoFiles |
@@ -213,6 +211,8 @@
 		},
 		Tests: true,
 	}
+	packagesinternal.SetModFile(cfg, inv.ModFile)
+	packagesinternal.SetModFlag(cfg, inv.ModFlag)
 	// We want to type check cgo code if go/types supports it.
 	if typesinternal.SetUsesCgo(&types.Config{}) {
 		cfg.Mode |= packages.LoadMode(packagesinternal.TypecheckCgo)
@@ -221,8 +221,8 @@
 	return cfg
 }
 
-func (s *snapshot) RunGoCommandDirect(ctx context.Context, inv *gocommand.Invocation) (*bytes.Buffer, error) {
-	_, inv, cleanup, err := s.goCommandInvocation(ctx, false, inv)
+func (s *snapshot) RunGoCommandDirect(ctx context.Context, mode source.InvocationMode, inv *gocommand.Invocation) (*bytes.Buffer, error) {
+	_, inv, cleanup, err := s.goCommandInvocation(ctx, mode, inv)
 	if err != nil {
 		return nil, err
 	}
@@ -231,8 +231,8 @@
 	return s.view.session.gocmdRunner.Run(ctx, *inv)
 }
 
-func (s *snapshot) RunGoCommandPiped(ctx context.Context, inv *gocommand.Invocation, stdout, stderr io.Writer) error {
-	_, inv, cleanup, err := s.goCommandInvocation(ctx, true, inv)
+func (s *snapshot) RunGoCommandPiped(ctx context.Context, mode source.InvocationMode, inv *gocommand.Invocation, stdout, stderr io.Writer) error {
+	_, inv, cleanup, err := s.goCommandInvocation(ctx, mode, inv)
 	if err != nil {
 		return err
 	}
@@ -240,16 +240,49 @@
 	return s.view.session.gocmdRunner.RunPiped(ctx, *inv, stdout, stderr)
 }
 
-func (s *snapshot) goCommandInvocation(ctx context.Context, allowTempModfile bool, inv *gocommand.Invocation) (tmpURI span.URI, updatedInv *gocommand.Invocation, cleanup func(), err error) {
+func (s *snapshot) goCommandInvocation(ctx context.Context, mode source.InvocationMode, inv *gocommand.Invocation) (tmpURI span.URI, updatedInv *gocommand.Invocation, cleanup func(), err error) {
 	s.view.optionsMu.Lock()
-	env := s.view.options.EnvSlice()
+	inv.Env = append(append(append([]string{}, s.view.options.EnvSlice()...), inv.Env...), "GO111MODULE="+s.view.go111module)
+	inv.BuildFlags = append([]string{}, s.view.options.BuildFlags...)
 	s.view.optionsMu.Unlock()
-
 	cleanup = func() {} // fallback
-	inv.Env = append(append(append([]string{}, env...), inv.Env...), "GO111MODULE="+s.view.go111module)
 
-	modURI := s.GoModForFile(ctx, span.URIFromPath(inv.WorkingDir))
-	if allowTempModfile && s.workspaceMode()&tempModfile != 0 {
+	var modURI span.URI
+	if s.workspaceMode()&moduleMode != 0 {
+		// Select the module context to use.
+		// If we're type checking, we need to use the workspace context, meaning
+		// the main (workspace) module. Otherwise, we should use the module for
+		// the passed-in working dir.
+		if mode == source.ForTypeChecking {
+			if s.workspaceMode()&usesWorkspaceModule == 0 {
+				var mod *moduleRoot
+				for _, m := range s.modules { // range to access the only element
+					mod = m
+				}
+				modURI = mod.modURI
+			} else {
+				var tmpDir span.URI
+				var err error
+				tmpDir, cleanup, err = s.tempWorkspaceModule(ctx)
+				if err != nil {
+					return "", nil, cleanup, err
+				}
+				inv.WorkingDir = tmpDir.Filename()
+				modURI = span.URIFromPath(filepath.Join(tmpDir.Filename(), "go.mod"))
+			}
+		} else {
+			modURI = s.GoModForFile(ctx, span.URIFromPath(inv.WorkingDir))
+		}
+	}
+
+	wantTempMod := mode != source.UpdateUserModFile
+	needTempMod := mode == source.WriteTemporaryModFile
+	tempMod := wantTempMod && s.workspaceMode()&tempModfile != 0
+	if needTempMod && !tempMod {
+		return "", nil, cleanup, source.ErrTmpModfileUnsupported
+	}
+
+	if tempMod {
 		if modURI == "" {
 			return "", nil, cleanup, fmt.Errorf("no go.mod file found in %s", inv.WorkingDir)
 		}
diff --git a/internal/lsp/command.go b/internal/lsp/command.go
index 80e762c..134d3fd 100644
--- a/internal/lsp/command.go
+++ b/internal/lsp/command.go
@@ -265,7 +265,7 @@
 	}
 	snapshot, release := view.Snapshot(ctx)
 	defer release()
-	_, err = snapshot.RunGoCommandDirect(ctx, &gocommand.Invocation{
+	_, err = snapshot.RunGoCommandDirect(ctx, source.UpdateUserModFile, &gocommand.Invocation{
 		Verb:       verb,
 		Args:       args,
 		WorkingDir: filepath.Dir(uri.SpanURI().Filename()),
@@ -296,7 +296,7 @@
 			Args:       []string{pkgPath, "-v", "-count=1", "-run", fmt.Sprintf("^%s$", funcName)},
 			WorkingDir: filepath.Dir(uri.SpanURI().Filename()),
 		}
-		if err := snapshot.RunGoCommandPiped(ctx, inv, out, out); err != nil {
+		if err := snapshot.RunGoCommandPiped(ctx, source.Normal, inv, out, out); err != nil {
 			if errors.Is(err, context.Canceled) {
 				return err
 			}
@@ -312,7 +312,7 @@
 			Args:       []string{pkgPath, "-v", "-run=^$", "-bench", fmt.Sprintf("^%s$", funcName)},
 			WorkingDir: filepath.Dir(uri.SpanURI().Filename()),
 		}
-		if err := snapshot.RunGoCommandPiped(ctx, inv, out, out); err != nil {
+		if err := snapshot.RunGoCommandPiped(ctx, source.Normal, inv, out, out); err != nil {
 			if errors.Is(err, context.Canceled) {
 				return err
 			}
@@ -367,7 +367,7 @@
 		WorkingDir: dir.Filename(),
 	}
 	stderr := io.MultiWriter(er, workDoneWriter{work})
-	if err := snapshot.RunGoCommandPiped(ctx, inv, er, stderr); err != nil {
+	if err := snapshot.RunGoCommandPiped(ctx, source.Normal, inv, er, stderr); err != nil {
 		return err
 	}
 	return nil
diff --git a/internal/lsp/source/gc_annotations.go b/internal/lsp/source/gc_annotations.go
index 0a0d2b5..8bf8ad7 100644
--- a/internal/lsp/source/gc_annotations.go
+++ b/internal/lsp/source/gc_annotations.go
@@ -46,7 +46,7 @@
 		},
 		WorkingDir: pkgDir.Filename(),
 	}
-	_, err = snapshot.RunGoCommandDirect(ctx, inv)
+	_, err = snapshot.RunGoCommandDirect(ctx, Normal, inv)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go
index 2f3a212..8a12e0c 100644
--- a/internal/lsp/source/view.go
+++ b/internal/lsp/source/view.go
@@ -84,11 +84,11 @@
 
 	// RunGoCommandPiped runs the given `go` command, writing its output
 	// to stdout and stderr. Verb, Args, and WorkingDir must be specified.
-	RunGoCommandPiped(ctx context.Context, inv *gocommand.Invocation, stdout, stderr io.Writer) error
+	RunGoCommandPiped(ctx context.Context, mode InvocationMode, inv *gocommand.Invocation, stdout, stderr io.Writer) error
 
 	// RunGoCommandDirect runs the given `go` command. Verb, Args, and
 	// WorkingDir must be specified.
-	RunGoCommandDirect(ctx context.Context, inv *gocommand.Invocation) (*bytes.Buffer, error)
+	RunGoCommandDirect(ctx context.Context, mode InvocationMode, inv *gocommand.Invocation) (*bytes.Buffer, error)
 
 	// RunProcessEnvFunc runs fn with the process env for this snapshot's view.
 	// Note: the process env contains cached module and filesystem state.
@@ -169,6 +169,24 @@
 	WidestPackage
 )
 
+// InvocationMode represents the goal of a particular go command invocation.
+type InvocationMode int
+
+const (
+	// Normal is appropriate for commands that might be run by a user and don't
+	// deliberately modify go.mod files, e.g. `go test`.
+	Normal = iota
+	// UpdateUserModFile is for commands that intend to update the user's real
+	// go.mod file, e.g. `go mod tidy` in response to a user's request to tidy.
+	UpdateUserModFile
+	// WriteTemporaryModFile is for commands that need information from a
+	// modified version of the user's go.mod file, e.g. `go mod tidy` used to
+	// generate diagnostics.
+	WriteTemporaryModFile
+	// ForTypeChecking is for packages.Load.
+	ForTypeChecking
+)
+
 // View represents a single workspace.
 // This is the level at which we maintain configuration like working directory
 // and build tags.