go/expect: add marker support for go.mod files

This change adds some basic marker support for go.mod files inside of go/expect. It requires all markers to be of the form "//@mark()", where mark can be anything. It is the same format as .go files, only difference is that it needs to have "//" since that is the only comment marker that go.mod files recognize.

Updates golang/go#36091

Change-Id: Ib9e325e01020181b8cee1c1be6bb257726ce913d
Reviewed-on: https://go-review.googlesource.com/c/tools/+/216838
Run-TryBot: Rohan Challa <rohan@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
diff --git a/go/expect/expect.go b/go/expect/expect.go
index ac9975c..752f36a 100644
--- a/go/expect/expect.go
+++ b/go/expect/expect.go
@@ -56,6 +56,7 @@
 	"bytes"
 	"fmt"
 	"go/token"
+	"path/filepath"
 	"regexp"
 )
 
@@ -89,12 +90,20 @@
 		return token.NoPos, token.NoPos, fmt.Errorf("invalid file: %v", err)
 	}
 	position := f.Position(end)
-	startOffset := f.Offset(lineStart(f, position.Line))
+	startOffset := f.Offset(f.LineStart(position.Line))
 	endOffset := f.Offset(end)
 	line := content[startOffset:endOffset]
 	matchStart, matchEnd := -1, -1
 	switch pattern := pattern.(type) {
 	case string:
+		// If the file is a go.mod and we are matching // indirect, then we
+		// need to look for it on the line after the current line.
+		// TODO(golang/go#36894): have a more intuitive approach for // indirect
+		if filepath.Ext(f.Name()) == ".mod" && pattern == "// indirect" {
+			startOffset = f.Offset(f.LineStart(position.Line + 1))
+			endOffset = f.Offset(lineEnd(f, position.Line+1))
+			line = content[startOffset:endOffset]
+		}
 		bytePattern := []byte(pattern)
 		matchStart = bytes.Index(line, bytePattern)
 		if matchStart >= 0 {
@@ -118,32 +127,9 @@
 	return f.Pos(startOffset + matchStart), f.Pos(startOffset + matchEnd), nil
 }
 
-// this functionality was borrowed from the analysisutil package
-func lineStart(f *token.File, line int) token.Pos {
-	// Use binary search to find the start offset of this line.
-	//
-	// TODO(adonovan): eventually replace this function with the
-	// simpler and more efficient (*go/token.File).LineStart, added
-	// in go1.12.
-
-	min := 0        // inclusive
-	max := f.Size() // exclusive
-	for {
-		offset := (min + max) / 2
-		pos := f.Pos(offset)
-		posn := f.Position(pos)
-		if posn.Line == line {
-			return pos - (token.Pos(posn.Column) - 1)
-		}
-
-		if min+1 >= max {
-			return token.NoPos
-		}
-
-		if posn.Line < line {
-			min = offset
-		} else {
-			max = offset
-		}
+func lineEnd(f *token.File, line int) token.Pos {
+	if line >= f.LineCount() {
+		return token.Pos(f.Size() + 1)
 	}
+	return f.LineStart(line + 1)
 }
diff --git a/go/expect/expect_test.go b/go/expect/expect_test.go
index c00a0d7..bd6e437 100644
--- a/go/expect/expect_test.go
+++ b/go/expect/expect_test.go
@@ -14,102 +14,127 @@
 )
 
 func TestMarker(t *testing.T) {
-	const filename = "testdata/test.go"
-	content, err := ioutil.ReadFile(filename)
-	if err != nil {
-		t.Fatal(err)
-	}
+	for _, tt := range []struct {
+		filename      string
+		expectNotes   int
+		expectMarkers map[string]string
+		expectChecks  map[string][]interface{}
+	}{
+		{
+			filename:    "testdata/test.go",
+			expectNotes: 13,
+			expectMarkers: map[string]string{
+				"αSimpleMarker": "α",
+				"OffsetMarker":  "β",
+				"RegexMarker":   "γ",
+				"εMultiple":     "ε",
+				"ζMarkers":      "ζ",
+				"ηBlockMarker":  "η",
+				"Declared":      "η",
+				"Comment":       "ι",
+				"LineComment":   "someFunc",
+				"NonIdentifier": "+",
+				"StringMarker":  "\"hello\"",
+			},
+			expectChecks: map[string][]interface{}{
+				"αSimpleMarker": nil,
+				"StringAndInt":  []interface{}{"Number %d", int64(12)},
+				"Bool":          []interface{}{true},
+			},
+		},
+		{
+			filename:    "testdata/go.mod",
+			expectNotes: 3,
+			expectMarkers: map[string]string{
+				"αMarker":        "αfake1α",
+				"IndirectMarker": "// indirect",
+				"βMarker":        "require golang.org/modfile v0.0.0",
+			},
+		},
+	} {
+		t.Run(tt.filename, func(t *testing.T) {
+			content, err := ioutil.ReadFile(tt.filename)
+			if err != nil {
+				t.Fatal(err)
+			}
+			readFile := func(string) ([]byte, error) { return content, nil }
 
-	const expectNotes = 13
-	expectMarkers := map[string]string{
-		"αSimpleMarker": "α",
-		"OffsetMarker":  "β",
-		"RegexMarker":   "γ",
-		"εMultiple":     "ε",
-		"ζMarkers":      "ζ",
-		"ηBlockMarker":  "η",
-		"Declared":      "η",
-		"Comment":       "ι",
-		"LineComment":   "someFunc",
-		"NonIdentifier": "+",
-		"StringMarker":  "\"hello\"",
-	}
-	expectChecks := map[string][]interface{}{
-		"αSimpleMarker": nil,
-		"StringAndInt":  []interface{}{"Number %d", int64(12)},
-		"Bool":          []interface{}{true},
-	}
-
-	readFile := func(string) ([]byte, error) { return content, nil }
-	markers := make(map[string]token.Pos)
-	for name, tok := range expectMarkers {
-		offset := bytes.Index(content, []byte(tok))
-		markers[name] = token.Pos(offset + 1)
-		end := bytes.Index(content[offset:], []byte(tok))
-		if end > 0 {
-			markers[name+"@"] = token.Pos(offset + end + 2)
-		}
-	}
-
-	fset := token.NewFileSet()
-	notes, err := expect.Parse(fset, filename, nil)
-	if err != nil {
-		t.Fatalf("Failed to extract notes: %v", err)
-	}
-	if len(notes) != expectNotes {
-		t.Errorf("Expected %v notes, got %v", expectNotes, len(notes))
-	}
-	for _, n := range notes {
-		switch {
-		case n.Args == nil:
-			// A //@foo note associates the name foo with the position of the
-			// first match of "foo" on the current line.
-			checkMarker(t, fset, readFile, markers, n.Pos, n.Name, n.Name)
-		case n.Name == "mark":
-			// A //@mark(name, "pattern") note associates the specified name
-			// with the position on the first match of pattern on the current line.
-			if len(n.Args) != 2 {
-				t.Errorf("%v: expected 2 args to mark, got %v", fset.Position(n.Pos), len(n.Args))
-				continue
-			}
-			ident, ok := n.Args[0].(expect.Identifier)
-			if !ok {
-				t.Errorf("%v: identifier, got %T", fset.Position(n.Pos), n.Args[0])
-				continue
-			}
-			checkMarker(t, fset, readFile, markers, n.Pos, string(ident), n.Args[1])
-
-		case n.Name == "check":
-			// A //@check(args, ...) note specifies some hypothetical action to
-			// be taken by the test driver and its expected outcome.
-			// In this test, the action is to compare the arguments
-			// against expectChecks.
-			if len(n.Args) < 1 {
-				t.Errorf("%v: expected 1 args to check, got %v", fset.Position(n.Pos), len(n.Args))
-				continue
-			}
-			ident, ok := n.Args[0].(expect.Identifier)
-			if !ok {
-				t.Errorf("%v: identifier, got %T", fset.Position(n.Pos), n.Args[0])
-				continue
-			}
-			args, ok := expectChecks[string(ident)]
-			if !ok {
-				t.Errorf("%v: unexpected check %v", fset.Position(n.Pos), ident)
-				continue
-			}
-			if len(n.Args) != len(args)+1 {
-				t.Errorf("%v: expected %v args to check, got %v", fset.Position(n.Pos), len(args)+1, len(n.Args))
-				continue
-			}
-			for i, got := range n.Args[1:] {
-				if args[i] != got {
-					t.Errorf("%v: arg %d expected %v, got %v", fset.Position(n.Pos), i, args[i], got)
+			markers := make(map[string]token.Pos)
+			for name, tok := range tt.expectMarkers {
+				offset := bytes.Index(content, []byte(tok))
+				// Handle special case where we look for // indirect and we
+				// need to search the next line.
+				if tok == "// indirect" {
+					offset = bytes.Index(content, []byte(" "+tok)) + 1
+				}
+				markers[name] = token.Pos(offset + 1)
+				end := bytes.Index(content[offset:], []byte(tok))
+				if end > 0 {
+					markers[name+"@"] = token.Pos(offset + end + 2)
 				}
 			}
-		default:
-			t.Errorf("Unexpected note %v at %v", n.Name, fset.Position(n.Pos))
-		}
+
+			fset := token.NewFileSet()
+			notes, err := expect.Parse(fset, tt.filename, content)
+			if err != nil {
+				t.Fatalf("Failed to extract notes: %v", err)
+			}
+			if len(notes) != tt.expectNotes {
+				t.Errorf("Expected %v notes, got %v", tt.expectNotes, len(notes))
+			}
+			for _, n := range notes {
+				switch {
+				case n.Args == nil:
+					// A //@foo note associates the name foo with the position of the
+					// first match of "foo" on the current line.
+					checkMarker(t, fset, readFile, markers, n.Pos, n.Name, n.Name)
+				case n.Name == "mark":
+					// A //@mark(name, "pattern") note associates the specified name
+					// with the position on the first match of pattern on the current line.
+					if len(n.Args) != 2 {
+						t.Errorf("%v: expected 2 args to mark, got %v", fset.Position(n.Pos), len(n.Args))
+						continue
+					}
+					ident, ok := n.Args[0].(expect.Identifier)
+					if !ok {
+						t.Errorf("%v: identifier, got %T", fset.Position(n.Pos), n.Args[0])
+						continue
+					}
+					checkMarker(t, fset, readFile, markers, n.Pos, string(ident), n.Args[1])
+
+				case n.Name == "check":
+					// A //@check(args, ...) note specifies some hypothetical action to
+					// be taken by the test driver and its expected outcome.
+					// In this test, the action is to compare the arguments
+					// against expectChecks.
+					if len(n.Args) < 1 {
+						t.Errorf("%v: expected 1 args to check, got %v", fset.Position(n.Pos), len(n.Args))
+						continue
+					}
+					ident, ok := n.Args[0].(expect.Identifier)
+					if !ok {
+						t.Errorf("%v: identifier, got %T", fset.Position(n.Pos), n.Args[0])
+						continue
+					}
+					args, ok := tt.expectChecks[string(ident)]
+					if !ok {
+						t.Errorf("%v: unexpected check %v", fset.Position(n.Pos), ident)
+						continue
+					}
+					if len(n.Args) != len(args)+1 {
+						t.Errorf("%v: expected %v args to check, got %v", fset.Position(n.Pos), len(args)+1, len(n.Args))
+						continue
+					}
+					for i, got := range n.Args[1:] {
+						if args[i] != got {
+							t.Errorf("%v: arg %d expected %v, got %v", fset.Position(n.Pos), i, args[i], got)
+						}
+					}
+				default:
+					t.Errorf("Unexpected note %v at %v", n.Name, fset.Position(n.Pos))
+				}
+			}
+		})
 	}
 }
 
diff --git a/go/expect/extract.go b/go/expect/extract.go
index 249369f..67156cf 100644
--- a/go/expect/extract.go
+++ b/go/expect/extract.go
@@ -9,15 +9,17 @@
 	"go/ast"
 	"go/parser"
 	"go/token"
+	"path/filepath"
 	"regexp"
 	"strconv"
 	"strings"
 	"text/scanner"
+
+	"golang.org/x/mod/modfile"
 )
 
-const (
-	commentStart = "@"
-)
+const commentStart = "@"
+const commentStartLen = len(commentStart)
 
 // Identifier is the type for an identifier in an Note argument list.
 type Identifier string
@@ -34,52 +36,62 @@
 	if content != nil {
 		src = content
 	}
-	// TODO: We should write this in terms of the scanner.
-	// there are ways you can break the parser such that it will not add all the
-	// comments to the ast, which may result in files where the tests are silently
-	// not run.
-	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
-	if file == nil {
-		return nil, err
+	switch filepath.Ext(filename) {
+	case ".go":
+		// TODO: We should write this in terms of the scanner.
+		// there are ways you can break the parser such that it will not add all the
+		// comments to the ast, which may result in files where the tests are silently
+		// not run.
+		file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
+		if file == nil {
+			return nil, err
+		}
+		return ExtractGo(fset, file)
+	case ".mod":
+		file, err := modfile.Parse(filename, content, nil)
+		if err != nil {
+			return nil, err
+		}
+		fset.AddFile(filename, -1, len(content)).SetLinesForContent(content)
+		return extractMod(fset, file)
 	}
-	return Extract(fset, file)
+	return nil, nil
 }
 
-// Extract collects all the notes present in an AST.
+// extractMod collects all the notes present in a go.mod file.
 // Each comment whose text starts with @ is parsed as a comma-separated
 // sequence of notes.
 // See the package documentation for details about the syntax of those
 // notes.
-func Extract(fset *token.FileSet, file *ast.File) ([]*Note, error) {
+// Only allow notes to appear with the following format: "//@mark()" or // @mark()
+func extractMod(fset *token.FileSet, file *modfile.File) ([]*Note, error) {
 	var notes []*Note
-	for _, g := range file.Comments {
-		for _, c := range g.List {
-			text := c.Text
-			if strings.HasPrefix(text, "/*") {
-				text = strings.TrimSuffix(text, "*/")
-			}
-			text = text[2:] // remove "//" or "/*" prefix
-
-			// Allow notes to appear within comments.
-			// For example:
-			// "// //@mark()" is valid.
-			// "// @mark()" is not valid.
-			// "// /*@mark()*/" is not valid.
-			var adjust int
-			if i := strings.Index(text, commentStart); i > 2 {
-				// Get the text before the commentStart.
-				pre := text[i-2 : i]
-				if pre != "//" {
-					continue
-				}
-				text = text[i:]
-				adjust = i
-			}
-			if !strings.HasPrefix(text, commentStart) {
+	for _, stmt := range file.Syntax.Stmt {
+		comment := stmt.Comment()
+		if comment == nil {
+			continue
+		}
+		// Handle the case for markers of `// indirect` to be on the line before
+		// the require statement.
+		// TODO(golang/go#36894): have a more intuitive approach for // indirect
+		for _, cmt := range comment.Before {
+			text, adjust := getAdjustedNote(cmt.Token)
+			if text == "" {
 				continue
 			}
-			text = text[len(commentStart):]
-			parsed, err := parse(fset, token.Pos(int(c.Pos())+4+adjust), text)
+			parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text)
+			if err != nil {
+				return nil, err
+			}
+			notes = append(notes, parsed...)
+		}
+		// Handle the normal case for markers on the same line.
+		for _, cmt := range comment.Suffix {
+			text, adjust := getAdjustedNote(cmt.Token)
+			if text == "" {
+				continue
+			}
+			parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text)
 			if err != nil {
 				return nil, err
 			}
@@ -89,6 +101,57 @@
 	return notes, nil
 }
 
+// ExtractGo collects all the notes present in an AST.
+// Each comment whose text starts with @ is parsed as a comma-separated
+// sequence of notes.
+// See the package documentation for details about the syntax of those
+// notes.
+func ExtractGo(fset *token.FileSet, file *ast.File) ([]*Note, error) {
+	var notes []*Note
+	for _, g := range file.Comments {
+		for _, c := range g.List {
+			text, adjust := getAdjustedNote(c.Text)
+			if text == "" {
+				continue
+			}
+			parsed, err := parse(fset, token.Pos(int(c.Pos())+adjust), text)
+			if err != nil {
+				return nil, err
+			}
+			notes = append(notes, parsed...)
+		}
+	}
+	return notes, nil
+}
+
+func getAdjustedNote(text string) (string, int) {
+	if strings.HasPrefix(text, "/*") {
+		text = strings.TrimSuffix(text, "*/")
+	}
+	text = text[2:] // remove "//" or "/*" prefix
+
+	// Allow notes to appear within comments.
+	// For example:
+	// "// //@mark()" is valid.
+	// "// @mark()" is not valid.
+	// "// /*@mark()*/" is not valid.
+	var adjust int
+	if i := strings.Index(text, commentStart); i > 2 {
+		// Get the text before the commentStart.
+		pre := text[i-2 : i]
+		if pre != "//" {
+			return "", 0
+		}
+		text = text[i:]
+		adjust = i
+	}
+	if !strings.HasPrefix(text, commentStart) {
+		return "", 0
+	}
+	text = text[commentStartLen:]
+	return text, commentStartLen + adjust + 1
+}
+
 const invalidToken rune = 0
 
 type tokens struct {
diff --git a/go/expect/testdata/go.mod b/go/expect/testdata/go.mod
new file mode 100644
index 0000000..01a73f9
--- /dev/null
+++ b/go/expect/testdata/go.mod
@@ -0,0 +1,7 @@
+module αfake1α //@mark(αMarker, "αfake1α")
+
+go 1.14
+
+require golang.org/modfile v0.0.0 //@mark(βMarker, "require golang.org/modfile v0.0.0")
+//@mark(IndirectMarker, "// indirect")
+require golang.org/x/tools v0.0.0-20191219192050-56b0b28a00f7 // indirect
\ No newline at end of file
diff --git a/go/ssa/source_test.go b/go/ssa/source_test.go
index 9dc3c66..24cf57e 100644
--- a/go/ssa/source_test.go
+++ b/go/ssa/source_test.go
@@ -50,7 +50,7 @@
 	// Each note of the form @ssa(x, "BinOp") in testdata/objlookup.go
 	// specifies an expectation that an object named x declared on the
 	// same line is associated with an an ssa.Value of type *ssa.BinOp.
-	notes, err := expect.Extract(conf.Fset, f)
+	notes, err := expect.ExtractGo(conf.Fset, f)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -271,7 +271,7 @@
 		return true
 	})
 
-	notes, err := expect.Extract(prog.Fset, f)
+	notes, err := expect.ExtractGo(prog.Fset, f)
 	if err != nil {
 		t.Fatal(err)
 	}