[internal-branch.go1.16-vendor] modfile: defer fixing versions in retract directives

VersionFixers require both a path and a version: if the version is
non-canonical (like a branch name), they generally need the path to
look up the proper version. This is fine for require, replace, and
exclude directives, since the path is specified with each version. For
retract directives, the path comes from the module directive, which
may appear later in the file. Previously, we just used the empty
string, but this breaks reasonable implementations.

With this change, we leave retracted versions alone until the file has
been completely parsed, then we apply the version fixer to each
retract directive. We report an error if retract is used without a
module directive.

For golang/go#44496

Change-Id: I99b7b8b55941c1fde4ee56161acfe854bcaf948d
Reviewed-on: https://go-review.googlesource.com/c/mod/+/296130
Trust: Jay Conrod <jayconrod@google.com>
Run-TryBot: Jay Conrod <jayconrod@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Bryan C. Mills <bcmills@google.com>
(cherry picked from commit 66bf157bf5bcd4cf5e82f17680e12c7fc873a2c1)
Reviewed-on: https://go-review.googlesource.com/c/mod/+/298010
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
diff --git a/modfile/rule.go b/modfile/rule.go
index 8fcf96b..f8c9384 100644
--- a/modfile/rule.go
+++ b/modfile/rule.go
@@ -125,6 +125,12 @@
 
 type VersionFixer func(path, version string) (string, error)
 
+// errDontFix is returned by a VersionFixer to indicate the version should be
+// left alone, even if it's not canonical.
+var dontFixRetract VersionFixer = func(_, vers string) (string, error) {
+	return vers, nil
+}
+
 // Parse parses the data, reported in errors as being from file,
 // into a File struct. It applies fix, if non-nil, to canonicalize all module versions found.
 func Parse(file string, data []byte, fix VersionFixer) (*File, error) {
@@ -142,7 +148,7 @@
 	return parseToFile(file, data, fix, false)
 }
 
-func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (*File, error) {
+func parseToFile(file string, data []byte, fix VersionFixer, strict bool) (parsed *File, err error) {
 	fs, err := parse(file, data)
 	if err != nil {
 		return nil, err
@@ -150,8 +156,18 @@
 	f := &File{
 		Syntax: fs,
 	}
-
 	var errs ErrorList
+
+	// fix versions in retract directives after the file is parsed.
+	// We need the module path to fix versions, and it might be at the end.
+	defer func() {
+		oldLen := len(errs)
+		f.fixRetract(fix, &errs)
+		if len(errs) > oldLen {
+			parsed, err = nil, errs
+		}
+	}()
+
 	for _, x := range fs.Stmt {
 		switch x := x.(type) {
 		case *Line:
@@ -370,7 +386,7 @@
 
 	case "retract":
 		rationale := parseRetractRationale(block, line)
-		vi, err := parseVersionInterval(verb, &args, fix)
+		vi, err := parseVersionInterval(verb, "", &args, dontFixRetract)
 		if err != nil {
 			if strict {
 				wrapError(err)
@@ -397,6 +413,47 @@
 	}
 }
 
+// fixRetract applies fix to each retract directive in f, appending any errors
+// to errs.
+//
+// Most versions are fixed as we parse the file, but for retract directives,
+// the relevant module path is the one specified with the module directive,
+// and that might appear at the end of the file (or not at all).
+func (f *File) fixRetract(fix VersionFixer, errs *ErrorList) {
+	if fix == nil {
+		return
+	}
+	path := ""
+	if f.Module != nil {
+		path = f.Module.Mod.Path
+	}
+	var r *Retract
+	wrapError := func(err error) {
+		*errs = append(*errs, Error{
+			Filename: f.Syntax.Name,
+			Pos:      r.Syntax.Start,
+			Err:      err,
+		})
+	}
+
+	for _, r = range f.Retract {
+		if path == "" {
+			wrapError(errors.New("no module directive found, so retract cannot be used"))
+			return // only print the first one of these
+		}
+
+		args := r.Syntax.Token
+		if args[0] == "retract" {
+			args = args[1:]
+		}
+		vi, err := parseVersionInterval("retract", path, &args, fix)
+		if err != nil {
+			wrapError(err)
+		}
+		r.VersionInterval = vi
+	}
+}
+
 // isIndirect reports whether line has a "// indirect" comment,
 // meaning it is in go.mod only for its effect on indirect dependencies,
 // so that it can be dropped entirely once the effective version of the
@@ -491,13 +548,13 @@
 	return s
 }
 
-func parseVersionInterval(verb string, args *[]string, fix VersionFixer) (VersionInterval, error) {
+func parseVersionInterval(verb string, path string, args *[]string, fix VersionFixer) (VersionInterval, error) {
 	toks := *args
 	if len(toks) == 0 || toks[0] == "(" {
 		return VersionInterval{}, fmt.Errorf("expected '[' or version")
 	}
 	if toks[0] != "[" {
-		v, err := parseVersion(verb, "", &toks[0], fix)
+		v, err := parseVersion(verb, path, &toks[0], fix)
 		if err != nil {
 			return VersionInterval{}, err
 		}
@@ -509,7 +566,7 @@
 	if len(toks) == 0 {
 		return VersionInterval{}, fmt.Errorf("expected version after '['")
 	}
-	low, err := parseVersion(verb, "", &toks[0], fix)
+	low, err := parseVersion(verb, path, &toks[0], fix)
 	if err != nil {
 		return VersionInterval{}, err
 	}
@@ -523,7 +580,7 @@
 	if len(toks) == 0 {
 		return VersionInterval{}, fmt.Errorf("expected version after ','")
 	}
-	high, err := parseVersion(verb, "", &toks[0], fix)
+	high, err := parseVersion(verb, path, &toks[0], fix)
 	if err != nil {
 		return VersionInterval{}, err
 	}
@@ -631,8 +688,7 @@
 		}
 	}
 	if fix != nil {
-		var err error
-		t, err = fix(path, t)
+		fixed, err := fix(path, t)
 		if err != nil {
 			if err, ok := err.(*module.ModuleError); ok {
 				return "", &Error{
@@ -643,19 +699,23 @@
 			}
 			return "", err
 		}
+		t = fixed
+	} else {
+		cv := module.CanonicalVersion(t)
+		if cv == "" {
+			return "", &Error{
+				Verb:    verb,
+				ModPath: path,
+				Err: &module.InvalidVersionError{
+					Version: t,
+					Err:     errors.New("must be of the form v1.2.3"),
+				},
+			}
+		}
+		t = cv
 	}
-	if v := module.CanonicalVersion(t); v != "" {
-		*s = v
-		return *s, nil
-	}
-	return "", &Error{
-		Verb:    verb,
-		ModPath: path,
-		Err: &module.InvalidVersionError{
-			Version: t,
-			Err:     errors.New("must be of the form v1.2.3"),
-		},
-	}
+	*s = t
+	return *s, nil
 }
 
 func modulePathMajor(path string) (string, error) {
diff --git a/modfile/rule_test.go b/modfile/rule_test.go
index 2381ee6..96ef036 100644
--- a/modfile/rule_test.go
+++ b/modfile/rule_test.go
@@ -7,6 +7,7 @@
 import (
 	"bytes"
 	"fmt"
+	"strings"
 	"testing"
 
 	"golang.org/x/mod/module"
@@ -696,6 +697,59 @@
 	},
 }
 
+var fixVersionTests = []struct {
+	desc, in, want, wantErr string
+	fix                     VersionFixer
+}{
+	{
+		desc: `require`,
+		in:   `require example.com/m 1.0.0`,
+		want: `require example.com/m v1.0.0`,
+		fix:  fixV,
+	},
+	{
+		desc: `replace`,
+		in:   `replace example.com/m 1.0.0 => example.com/m 1.1.0`,
+		want: `replace example.com/m v1.0.0 => example.com/m v1.1.0`,
+		fix:  fixV,
+	},
+	{
+		desc: `exclude`,
+		in:   `exclude example.com/m 1.0.0`,
+		want: `exclude example.com/m v1.0.0`,
+		fix:  fixV,
+	},
+	{
+		desc: `retract_single`,
+		in: `module example.com/m
+		retract 1.0.0`,
+		want: `module example.com/m
+		retract v1.0.0`,
+		fix: fixV,
+	},
+	{
+		desc: `retract_interval`,
+		in: `module example.com/m
+		retract [1.0.0, 1.1.0]`,
+		want: `module example.com/m
+		retract [v1.0.0, v1.1.0]`,
+		fix: fixV,
+	},
+	{
+		desc:    `retract_nomod`,
+		in:      `retract 1.0.0`,
+		wantErr: `in:1: no module directive found, so retract cannot be used`,
+		fix:     fixV,
+	},
+}
+
+func fixV(path, version string) (string, error) {
+	if path != "example.com/m" {
+		return "", fmt.Errorf("module path must be example.com/m")
+	}
+	return "v" + version, nil
+}
+
 func TestAddRequire(t *testing.T) {
 	for _, tt := range addRequireTests {
 		t.Run(tt.desc, func(t *testing.T) {
@@ -877,3 +931,37 @@
 		})
 	}
 }
+
+func TestFixVersion(t *testing.T) {
+	for _, tt := range fixVersionTests {
+		t.Run(tt.desc, func(t *testing.T) {
+			inFile, err := Parse("in", []byte(tt.in), tt.fix)
+			if err != nil {
+				if tt.wantErr == "" {
+					t.Fatalf("unexpected error: %v", err)
+				}
+				if errMsg := err.Error(); !strings.Contains(errMsg, tt.wantErr) {
+					t.Fatalf("got error %q; want error containing %q", errMsg, tt.wantErr)
+				}
+				return
+			}
+			got, err := inFile.Format()
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			outFile, err := Parse("out", []byte(tt.want), nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			want, err := outFile.Format()
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if !bytes.Equal(got, want) {
+				t.Fatalf("got:\n%s\nwant:\n%s", got, want)
+			}
+		})
+	}
+}