[dev.cmdgo] cmd/go: fix calls to modFileGoVersion to pass in modFile

Before this change, we were arbitrarily picking a module to get the Go
version from in calls to modFileGoVersion. We now pass in the modFile to
modFileGoVersion when we have the file. Most of the calls were to get
the goVersion argument for commitRequirements, so now we have
commitRequirements call modFileGoVersion on the modFile directly
One of the calls to commitRequirements (when running go mod tidy with
a different Go version) passed in a new go version to update the file
to. Now, the modFile is updated before calling commitRequirements.

For the remaining cases of modFileGoVersion, it's replaced by a call to
the new (*MainModuleSet).GoVersion function, which either returns the go
version on the workspace file (in workspace mode) or the version of the
single go.mod file.

Change-Id: Ie88c3ca76c7f29ffc4faa16bb76f6cb7eccb5029
Reviewed-on: https://go-review.googlesource.com/c/go/+/344749
Trust: Michael Matloob <matloob@golang.org>
Run-TryBot: Michael Matloob <matloob@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Jay Conrod <jayconrod@google.com>
diff --git a/src/cmd/go/internal/modload/buildlist.go b/src/cmd/go/internal/modload/buildlist.go
index 14379b4..94414278 100644
--- a/src/cmd/go/internal/modload/buildlist.go
+++ b/src/cmd/go/internal/modload/buildlist.go
@@ -461,7 +461,7 @@
 		base.Fatalf("go: %v", err)
 	}
 
-	commitRequirements(ctx, modFileGoVersion(), rs)
+	commitRequirements(ctx, rs)
 	return mg
 }
 
@@ -527,7 +527,7 @@
 	if err != nil {
 		return false, err
 	}
-	commitRequirements(ctx, modFileGoVersion(), rs)
+	commitRequirements(ctx, rs)
 	return changed, err
 }
 
diff --git a/src/cmd/go/internal/modload/init.go b/src/cmd/go/internal/modload/init.go
index ab67338..b845842 100644
--- a/src/cmd/go/internal/modload/init.go
+++ b/src/cmd/go/internal/modload/init.go
@@ -59,8 +59,9 @@
 	// roots are required but MainModules hasn't been initialized yet. Set to
 	// the modRoots of the main modules.
 	// modRoots != nil implies len(modRoots) > 0
-	modRoots []string
-	gopath   string
+	modRoots          []string
+	gopath            string
+	workFileGoVersion string
 )
 
 // Variable set in InitWorkfile
@@ -91,6 +92,8 @@
 
 	modContainingCWD module.Version
 
+	workFileGoVersion string
+
 	indexMu sync.Mutex
 	indices map[module.Version]*modFileIndex
 }
@@ -190,6 +193,20 @@
 	return mms.modContainingCWD
 }
 
+// GoVersion returns the go version set on the single module, in module mode,
+// or the go.work file in workspace mode.
+func (mms *MainModuleSet) GoVersion() string {
+	if !inWorkspaceMode() {
+		return modFileGoVersion(mms.ModFile(mms.mustGetSingleMainModule()))
+	}
+	v := mms.workFileGoVersion
+	if v == "" {
+		// Fall back to 1.18 for go.work files.
+		v = "1.18"
+	}
+	return v
+}
+
 var MainModules *MainModuleSet
 
 type Root int
@@ -374,7 +391,7 @@
 
 	if inWorkspaceMode() {
 		var err error
-		modRoots, err = loadWorkFile(workFilePath)
+		workFileGoVersion, modRoots, err = loadWorkFile(workFilePath)
 		if err != nil {
 			base.Fatalf("reading go.work: %v", err)
 		}
@@ -536,16 +553,19 @@
 
 var errGoModDirty error = goModDirtyError{}
 
-func loadWorkFile(path string) (modRoots []string, err error) {
+func loadWorkFile(path string) (goVersion string, modRoots []string, err error) {
 	_ = TODOWorkspaces("Clean up and write back the go.work file: add module paths for workspace modules.")
 	workDir := filepath.Dir(path)
 	workData, err := lockedfile.Read(path)
 	if err != nil {
-		return nil, err
+		return "", nil, err
 	}
 	wf, err := modfile.ParseWork(path, workData, nil)
 	if err != nil {
-		return nil, err
+		return "", nil, err
+	}
+	if wf.Go != nil {
+		goVersion = wf.Go.Version
 	}
 	seen := map[string]bool{}
 	for _, d := range wf.Directory {
@@ -554,12 +574,12 @@
 			modRoot = filepath.Join(workDir, modRoot)
 		}
 		if seen[modRoot] {
-			return nil, fmt.Errorf("path %s appears multiple times in workspace", modRoot)
+			return "", nil, fmt.Errorf("path %s appears multiple times in workspace", modRoot)
 		}
 		seen[modRoot] = true
 		modRoots = append(modRoots, modRoot)
 	}
-	return modRoots, nil
+	return goVersion, modRoots, nil
 }
 
 // LoadModFile sets Target and, if there is a main module, parses the initial
@@ -582,7 +602,7 @@
 func LoadModFile(ctx context.Context) *Requirements {
 	rs, needCommit := loadModFile(ctx)
 	if needCommit {
-		commitRequirements(ctx, modFileGoVersion(), rs)
+		commitRequirements(ctx, rs)
 	}
 	return rs
 }
@@ -602,7 +622,7 @@
 	if len(modRoots) == 0 {
 		_ = TODOWorkspaces("Instead of creating a fake module with an empty modroot, make MainModules.Len() == 0 mean that we're in module mode but not inside any module.")
 		mainModule := module.Version{Path: "command-line-arguments"}
-		MainModules = makeMainModules([]module.Version{mainModule}, []string{""}, []*modfile.File{nil}, []*modFileIndex{nil})
+		MainModules = makeMainModules([]module.Version{mainModule}, []string{""}, []*modfile.File{nil}, []*modFileIndex{nil}, "")
 		goVersion := LatestGoVersion()
 		rawGoVersion.Store(mainModule, goVersion)
 		requirements = newRequirements(modDepthFromGoVersion(goVersion), nil, nil)
@@ -652,7 +672,7 @@
 		}
 	}
 
-	MainModules = makeMainModules(mainModules, modRoots, modFiles, indices)
+	MainModules = makeMainModules(mainModules, modRoots, modFiles, indices, workFileGoVersion)
 	setDefaultBuildMod() // possibly enable automatic vendoring
 	rs = requirementsFromModFiles(ctx, modFiles)
 
@@ -702,7 +722,7 @@
 				}
 			}
 		} else {
-			rawGoVersion.Store(mainModule, modFileGoVersion())
+			rawGoVersion.Store(mainModule, modFileGoVersion(MainModules.ModFile(mainModule)))
 		}
 	}
 
@@ -756,7 +776,7 @@
 	fmt.Fprintf(os.Stderr, "go: creating new go.mod: module %s\n", modPath)
 	modFile := new(modfile.File)
 	modFile.AddModuleStmt(modPath)
-	MainModules = makeMainModules([]module.Version{modFile.Module.Mod}, []string{modRoot}, []*modfile.File{modFile}, []*modFileIndex{nil})
+	MainModules = makeMainModules([]module.Version{modFile.Module.Mod}, []string{modRoot}, []*modfile.File{modFile}, []*modFileIndex{nil}, "")
 	addGoStmt(modFile, modFile.Module.Mod, LatestGoVersion()) // Add the go directive before converted module requirements.
 
 	convertedFrom, err := convertLegacyConfig(modFile, modRoot)
@@ -772,7 +792,7 @@
 	if err != nil {
 		base.Fatalf("go: %v", err)
 	}
-	commitRequirements(ctx, modFileGoVersion(), rs)
+	commitRequirements(ctx, rs)
 
 	// Suggest running 'go mod tidy' unless the project is empty. Even if we
 	// imported all the correct requirements above, we're probably missing
@@ -880,7 +900,7 @@
 
 // makeMainModules creates a MainModuleSet and associated variables according to
 // the given main modules.
-func makeMainModules(ms []module.Version, rootDirs []string, modFiles []*modfile.File, indices []*modFileIndex) *MainModuleSet {
+func makeMainModules(ms []module.Version, rootDirs []string, modFiles []*modfile.File, indices []*modFileIndex, workFileGoVersion string) *MainModuleSet {
 	for _, m := range ms {
 		if m.Version != "" {
 			panic("mainModulesCalled with module.Version with non empty Version field: " + fmt.Sprintf("%#v", m))
@@ -888,12 +908,13 @@
 	}
 	modRootContainingCWD := findModuleRoot(base.Cwd())
 	mainModules := &MainModuleSet{
-		versions:    ms[:len(ms):len(ms)],
-		inGorootSrc: map[module.Version]bool{},
-		pathPrefix:  map[module.Version]string{},
-		modRoot:     map[module.Version]string{},
-		modFiles:    map[module.Version]*modfile.File{},
-		indices:     map[module.Version]*modFileIndex{},
+		versions:          ms[:len(ms):len(ms)],
+		inGorootSrc:       map[module.Version]bool{},
+		pathPrefix:        map[module.Version]string{},
+		modRoot:           map[module.Version]string{},
+		modFiles:          map[module.Version]*modfile.File{},
+		indices:           map[module.Version]*modFileIndex{},
+		workFileGoVersion: workFileGoVersion,
 	}
 	for i, m := range ms {
 		mainModules.pathPrefix[m] = m.Path
@@ -958,7 +979,7 @@
 		}
 	}
 	module.Sort(roots)
-	rs := newRequirements(modDepthFromGoVersion(modFileGoVersion()), roots, direct)
+	rs := newRequirements(modDepthFromGoVersion(MainModules.GoVersion()), roots, direct)
 	return rs
 }
 
@@ -1315,12 +1336,13 @@
 	if !allowWriteGoMod {
 		panic("WriteGoMod called while disallowed")
 	}
-	commitRequirements(ctx, modFileGoVersion(), LoadModFile(ctx))
+	commitRequirements(ctx, LoadModFile(ctx))
 }
 
 // commitRequirements writes sets the global requirements variable to rs and
 // writes its contents back to the go.mod file on disk.
-func commitRequirements(ctx context.Context, goVersion string, rs *Requirements) {
+// goVersion, if non-empty, is used to set the version on the go.mod file.
+func commitRequirements(ctx context.Context, rs *Requirements) {
 	requirements = rs
 
 	if !allowWriteGoMod {
@@ -1352,10 +1374,10 @@
 			Indirect: !rs.direct[m.Path],
 		})
 	}
-	if goVersion != "" {
-		modFile.AddGoStmt(goVersion)
+	if modFile.Go == nil || modFile.Go.Version == "" {
+		modFile.AddGoStmt(modFileGoVersion(modFile))
 	}
-	if semver.Compare("v"+modFileGoVersion(), separateIndirectVersionV) < 0 {
+	if semver.Compare("v"+modFileGoVersion(modFile), separateIndirectVersionV) < 0 {
 		modFile.SetRequire(list)
 	} else {
 		modFile.SetRequireSeparateIndirect(list)
diff --git a/src/cmd/go/internal/modload/list.go b/src/cmd/go/internal/modload/list.go
index 1862bef..9c5018f 100644
--- a/src/cmd/go/internal/modload/list.go
+++ b/src/cmd/go/internal/modload/list.go
@@ -72,7 +72,7 @@
 	}
 
 	if err == nil {
-		commitRequirements(ctx, modFileGoVersion(), rs)
+		commitRequirements(ctx, rs)
 	}
 	return mods, err
 }
diff --git a/src/cmd/go/internal/modload/load.go b/src/cmd/go/internal/modload/load.go
index c9004ff..efe6ad1 100644
--- a/src/cmd/go/internal/modload/load.go
+++ b/src/cmd/go/internal/modload/load.go
@@ -407,11 +407,17 @@
 				base.Fatalf("go: %v", err)
 			}
 		}
+
+		// Update the go.mod file's Go version if necessary.
+		modFile := MainModules.ModFile(MainModules.mustGetSingleMainModule())
+		if ld.GoVersion != "" {
+			modFile.AddGoStmt(ld.GoVersion)
+		}
 	}
 
 	// Success! Update go.mod and go.sum (if needed) and return the results.
 	loaded = ld
-	commitRequirements(ctx, loaded.GoVersion, loaded.requirements)
+	commitRequirements(ctx, loaded.requirements)
 
 	for _, pkg := range ld.pkgs {
 		if !pkg.isTest() {
@@ -678,7 +684,7 @@
 			return roots
 		},
 	})
-	commitRequirements(ctx, loaded.GoVersion, loaded.requirements)
+	commitRequirements(ctx, loaded.requirements)
 }
 
 // DirImportPath returns the effective import path for dir,
@@ -960,7 +966,7 @@
 	}
 
 	if ld.GoVersion == "" {
-		ld.GoVersion = modFileGoVersion()
+		ld.GoVersion = MainModules.GoVersion()
 
 		if ld.Tidy && semver.Compare("v"+ld.GoVersion, "v"+LatestGoVersion()) > 0 {
 			ld.errorf("go mod tidy: go.mod file indicates go %s, but maximum supported version is %s\n", ld.GoVersion, LatestGoVersion())
@@ -1836,7 +1842,7 @@
 		fmt.Fprintln(os.Stderr)
 
 		goFlag := ""
-		if ld.GoVersion != modFileGoVersion() {
+		if ld.GoVersion != MainModules.GoVersion() {
 			goFlag = " -go=" + ld.GoVersion
 		}
 
diff --git a/src/cmd/go/internal/modload/modfile.go b/src/cmd/go/internal/modload/modfile.go
index 09e9c67..4638699 100644
--- a/src/cmd/go/internal/modload/modfile.go
+++ b/src/cmd/go/internal/modload/modfile.go
@@ -57,12 +57,8 @@
 )
 
 // modFileGoVersion returns the (non-empty) Go version at which the requirements
-// in modFile are intepreted, or the latest Go version if modFile is nil.
-func modFileGoVersion() string {
-	_ = TODOWorkspaces("this is obviously wrong.")
-	// Yes we're picking arbitrarily, we'll have to pass through the version
-	// we care about
-	modFile := MainModules.ModFile(MainModules.Versions()[0])
+// in modFile are interpreted, or the latest Go version if modFile is nil.
+func modFileGoVersion(modFile *modfile.File) string {
 	if modFile == nil {
 		return LatestGoVersion()
 	}