modfile: check canonicalness against the relevant module path, not abstract semver

In CL 279394 we started validating that versions in "exclude" and
"retract" directives are canonical. Unfortunately, we use the semver
package's notion of canonicalness, and the semver package doesn't know
anything about +incompatible versions or major-version suffixes.

The resulting error messages also don't indicate an appropriate fix if
the problem is that the user forgot either the "+incompatible" suffix
on the version string or the "/vN" suffix on the module path.

This change corrects both of those problems by validating the version
against the corresponding module path. (For "exclude" directives, that
is the module path to be excluded; for "retract" directives, it is the
module declared in the "module" directive of the same go.mod file.)

For golang/go#44497

Change-Id: I39732d79c3ab3a43bb1fb8905062fe6cb26d3edc
Reviewed-on: https://go-review.googlesource.com/c/mod/+/295089
Trust: Bryan C. Mills <bcmills@google.com>
Reviewed-by: Jay Conrod <jayconrod@google.com>
Reviewed-by: Michael Matloob <matloob@golang.org>
diff --git a/modfile/rule.go b/modfile/rule.go
index c6a189d..8fcf96b 100644
--- a/modfile/rule.go
+++ b/modfile/rule.go
@@ -835,11 +835,8 @@
 // AddExclude adds a exclude statement to the mod file. Errors if the provided
 // version is not a canonical version string
 func (f *File) AddExclude(path, vers string) error {
-	if !isCanonicalVersion(vers) {
-		return &module.InvalidVersionError{
-			Version: vers,
-			Err:     errors.New("must be of the form v1.2.3"),
-		}
+	if err := checkCanonicalVersion(path, vers); err != nil {
+		return err
 	}
 
 	var hint *Line
@@ -916,17 +913,15 @@
 // AddRetract adds a retract statement to the mod file. Errors if the provided
 // version interval does not consist of canonical version strings
 func (f *File) AddRetract(vi VersionInterval, rationale string) error {
-	if !isCanonicalVersion(vi.High) {
-		return &module.InvalidVersionError{
-			Version: vi.High,
-			Err:     errors.New("must be of the form v1.2.3"),
-		}
+	var path string
+	if f.Module != nil {
+		path = f.Module.Mod.Path
 	}
-	if !isCanonicalVersion(vi.Low) {
-		return &module.InvalidVersionError{
-			Version: vi.Low,
-			Err:     errors.New("must be of the form v1.2.3"),
-		}
+	if err := checkCanonicalVersion(path, vi.High); err != nil {
+		return err
+	}
+	if err := checkCanonicalVersion(path, vi.Low); err != nil {
+		return err
 	}
 
 	r := &Retract{
@@ -1086,8 +1081,40 @@
 	return semver.Compare(vii.High, vij.High) > 0
 }
 
-// isCanonicalVersion tests if the provided version string represents a valid
-// canonical version.
-func isCanonicalVersion(vers string) bool {
-	return vers != "" && semver.Canonical(vers) == vers
+// checkCanonicalVersion returns a non-nil error if vers is not a canonical
+// version string or does not match the major version of path.
+//
+// If path is non-empty, the error text suggests a format with a major version
+// corresponding to the path.
+func checkCanonicalVersion(path, vers string) error {
+	_, pathMajor, pathMajorOk := module.SplitPathVersion(path)
+
+	if vers == "" || vers != module.CanonicalVersion(vers) {
+		if pathMajor == "" {
+			return &module.InvalidVersionError{
+				Version: vers,
+				Err:     fmt.Errorf("must be of the form v1.2.3"),
+			}
+		}
+		return &module.InvalidVersionError{
+			Version: vers,
+			Err:     fmt.Errorf("must be of the form %s.2.3", module.PathMajorPrefix(pathMajor)),
+		}
+	}
+
+	if pathMajorOk {
+		if err := module.CheckPathMajor(vers, pathMajor); err != nil {
+			if pathMajor == "" {
+				// In this context, the user probably wrote "v2.3.4" when they meant
+				// "v2.3.4+incompatible". Suggest that instead of "v0 or v1".
+				return &module.InvalidVersionError{
+					Version: vers,
+					Err:     fmt.Errorf("should be %s+incompatible (or module %s/%v)", vers, path, semver.Major(vers)),
+				}
+			}
+			return err
+		}
+	}
+
+	return nil
 }
diff --git a/modfile/rule_test.go b/modfile/rule_test.go
index 03123ed..2381ee6 100644
--- a/modfile/rule_test.go
+++ b/modfile/rule_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"bytes"
+	"fmt"
 	"testing"
 
 	"golang.org/x/mod/module"
@@ -189,6 +190,45 @@
 	},
 }
 
+var addExcludeTests = []struct {
+	desc    string
+	in      string
+	path    string
+	version string
+	out     string
+}{
+	{
+		`compatible`,
+		`module m
+		`,
+		`example.com`,
+		`v1.2.3`,
+		`module m
+		exclude example.com v1.2.3
+		`,
+	},
+	{
+		`gopkg.in v0`,
+		`module m
+		`,
+		`gopkg.in/foo.v0`,
+		`v0.2.3`,
+		`module m
+		exclude gopkg.in/foo.v0 v0.2.3
+		`,
+	},
+	{
+		`gopkg.in v1`,
+		`module m
+		`,
+		`gopkg.in/foo.v1`,
+		`v1.2.3`,
+		`module m
+		exclude gopkg.in/foo.v1 v1.2.3
+		`,
+	},
+}
+
 var addRetractTests = []struct {
 	desc      string
 	in        string
@@ -569,44 +609,90 @@
 }
 
 var addRetractValidateVersionTests = []struct {
-	dsc, low, high string
+	desc      string
+	path      string
+	low, high string
+	wantErr   string
 }{
 	{
-		"blank_version",
-		"",
-		"",
+		`blank_version`,
+		`example.com/m`,
+		``,
+		``,
+		`version "" invalid: must be of the form v1.2.3`,
 	},
 	{
-		"missing_prefix",
-		"1.0.0",
-		"1.0.0",
+		`missing prefix`,
+		`example.com/m`,
+		`1.0.0`,
+		`1.0.0`,
+		`version "1.0.0" invalid: must be of the form v1.2.3`,
 	},
 	{
-		"non_canonical",
-		"v1.2",
-		"v1.2",
+		`non-canonical`,
+		`example.com/m`,
+		`v1.2`,
+		`v1.2`,
+		`version "v1.2" invalid: must be of the form v1.2.3`,
 	},
 	{
-		"invalid_range",
-		"v1.2.3",
-		"v1.3",
+		`invalid range`,
+		`example.com/m`,
+		`v1.2.3`,
+		`v1.3`,
+		`version "v1.3" invalid: must be of the form v1.2.3`,
+	},
+	{
+		`mismatched major`,
+		`example.com/m/v2`,
+		`v1.0.0`,
+		`v1.0.0`,
+		`version "v1.0.0" invalid: should be v2, not v1`,
+	},
+	{
+		`missing +incompatible`,
+		`example.com/m`,
+		`v2.0.0`,
+		`v2.0.0`,
+		`version "v2.0.0" invalid: should be v2.0.0+incompatible (or module example.com/m/v2)`,
 	},
 }
 
 var addExcludeValidateVersionTests = []struct {
-	dsc, ver string
+	desc    string
+	path    string
+	version string
+	wantErr string
 }{
 	{
-		"blank_version",
-		"",
+		`blank version`,
+		`example.com/m`,
+		``,
+		`version "" invalid: must be of the form v1.2.3`,
 	},
 	{
-		"missing_prefix",
-		"1.0.0",
+		`missing prefix`,
+		`example.com/m`,
+		`1.0.0`,
+		`version "1.0.0" invalid: must be of the form v1.2.3`,
 	},
 	{
-		"non_canonical",
-		"v1.2",
+		`non-canonical`,
+		`example.com/m`,
+		`v1.2`,
+		`version "v1.2" invalid: must be of the form v1.2.3`,
+	},
+	{
+		`mismatched major`,
+		`example.com/m/v2`,
+		`v1.2.3`,
+		`version "v1.2.3" invalid: should be v2, not v1`,
+	},
+	{
+		`missing +incompatible`,
+		`example.com/m`,
+		`v2.3.4`,
+		`version "v2.3.4" invalid: should be v2.3.4+incompatible (or module example.com/m/v2)`,
 	},
 }
 
@@ -657,6 +743,16 @@
 	}
 }
 
+func TestAddExclude(t *testing.T) {
+	for _, tt := range addExcludeTests {
+		t.Run(tt.desc, func(t *testing.T) {
+			testEdit(t, tt.in, tt.out, true, func(f *File) error {
+				return f.AddExclude(tt.path, tt.version)
+			})
+		})
+	}
+}
+
 func TestAddRetract(t *testing.T) {
 	for _, tt := range addRetractTests {
 		t.Run(tt.desc, func(t *testing.T) {
@@ -744,13 +840,21 @@
 
 func TestAddRetractValidateVersion(t *testing.T) {
 	for _, tt := range addRetractValidateVersionTests {
-		t.Run(tt.dsc, func(t *testing.T) {
-			f, err := Parse("in", []byte("module m"), nil)
-			if err != nil {
-				t.Fatal(err)
+		t.Run(tt.desc, func(t *testing.T) {
+			f := new(File)
+			if tt.path != "" {
+				if err := f.AddModuleStmt(tt.path); err != nil {
+					t.Fatal(err)
+				}
+				t.Logf("module %s", AutoQuote(tt.path))
 			}
-			if err = f.AddRetract(VersionInterval{Low: tt.low, High: tt.high}, ""); err == nil {
-				t.Fatal("expected AddRetract to complain about version format")
+			interval := VersionInterval{Low: tt.low, High: tt.high}
+			if err := f.AddRetract(interval, ``); err == nil || err.Error() != tt.wantErr {
+				errStr := "<nil>"
+				if err != nil {
+					errStr = fmt.Sprintf("%#q", err)
+				}
+				t.Fatalf("f.AddRetract(%+v, ``) = %s\nwant %#q", interval, errStr, tt.wantErr)
 			}
 		})
 	}
@@ -758,13 +862,17 @@
 
 func TestAddExcludeValidateVersion(t *testing.T) {
 	for _, tt := range addExcludeValidateVersionTests {
-		t.Run(tt.dsc, func(t *testing.T) {
+		t.Run(tt.desc, func(t *testing.T) {
 			f, err := Parse("in", []byte("module m"), nil)
 			if err != nil {
 				t.Fatal(err)
 			}
-			if err = f.AddExclude("aa", tt.ver); err == nil {
-				t.Fatal("expected AddExclude to complain about version format")
+			if err = f.AddExclude(tt.path, tt.version); err == nil || err.Error() != tt.wantErr {
+				errStr := "<nil>"
+				if err != nil {
+					errStr = fmt.Sprintf("%#q", err)
+				}
+				t.Fatalf("f.AddExclude(%q, %q) = %s\nwant %#q", tt.path, tt.version, errStr, tt.wantErr)
 			}
 		})
 	}