cmd/vulnreport: automatically fetch modules in symbol checks

The "vulnreport fix" command loads affected modules to add exported
symbols. (Eventually, this command should also validate the listed
symbols exist in the module, but it doesn't do that right now.)

Running fix requires some manual effort on the user's part.
From the triage documentation:

  mkdir /tmp/mymod
  cd /tmp/mymod
  go mod init
  go get github.com/my/mod@<version-before-fixed>
  go run <path to /cmd/vulnreport> fix

Automate this.

Detemining the "version-before-fixed" is programmatically is difficult.
Rather than trying to do so, add a "vulnerable_at" field to reports
which specifies a known-vulnerable version to use. Placing this in
the report also makes whatever work fix does more reproducable, since
we have an audit trail of what version was used.

Change-Id: Ie76d582a1f5192597f411b60eb407c2c014a9d35
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/412395
Reviewed-by: Julie Qiu <julieqiu@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/cmd/vulnreport/exported_functions.go b/cmd/vulnreport/exported_functions.go
index 7c335c9..570c374 100644
--- a/cmd/vulnreport/exported_functions.go
+++ b/cmd/vulnreport/exported_functions.go
@@ -6,7 +6,6 @@
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"os"
 	"strings"
@@ -47,10 +46,7 @@
 func exportedFunctions(pkgs []*packages.Package, rc *reportClient) (_ map[string]bool, err error) {
 	defer derrors.Wrap(&err, "exportedFunctions(%q)", pkgs[0].PkgPath)
 
-	if pkgs[0].Module == nil {
-		return nil, errors.New("pkgs[0] is missing Module")
-	}
-	if !affected(rc.entry, pkgs[0].Module.Version) {
+	if pkgs[0].Module != nil && !affected(rc.entry, pkgs[0].Module.Version) {
 		fmt.Fprintf(os.Stderr, "version %s of module %s is not affected by this vuln\n",
 			pkgs[0].Module.Version, pkgs[0].Module.Path)
 		return map[string]bool{}, nil
diff --git a/cmd/vulnreport/main.go b/cmd/vulnreport/main.go
index 9e57331..e0eec04 100644
--- a/cmd/vulnreport/main.go
+++ b/cmd/vulnreport/main.go
@@ -15,7 +15,9 @@
 	"go/build"
 	"log"
 	"os"
+	"os/exec"
 	"regexp"
+	"runtime"
 	"sort"
 	"strconv"
 	"strings"
@@ -33,11 +35,11 @@
 )
 
 var (
-	localRepoPath       = flag.String("local-cve-repo", "", "path to local repo, instead of cloning remote")
-	issueRepo           = flag.String("issue-repo", "github.com/golang/vulndb", "repo to create issues in")
-	githubToken         = flag.String("ghtoken", os.Getenv("VULN_GITHUB_ACCESS_TOKEN"), "GitHub access token")
-	skipExportedSymbols = flag.Bool("skip-exported", false, "for fix, don't look for exported symbols")
-	alwaysFixGHSA       = flag.Bool("always-fix-ghsa", false, "for fix, always update GHSAs")
+	localRepoPath = flag.String("local-cve-repo", "", "path to local repo, instead of cloning remote")
+	issueRepo     = flag.String("issue-repo", "github.com/golang/vulndb", "repo to create issues in")
+	githubToken   = flag.String("ghtoken", os.Getenv("VULN_GITHUB_ACCESS_TOKEN"), "GitHub access token")
+	skipSymbols   = flag.Bool("skip-symbols", false, "for lint and fix, don't load package for symbols checks")
+	alwaysFixGHSA = flag.Bool("always-fix-ghsa", false, "for fix, always update GHSAs")
 )
 
 func main() {
@@ -249,8 +251,8 @@
 	if lints := r.Lint(); len(lints) > 0 {
 		r.Fix()
 	}
-	if !*skipExportedSymbols {
-		if _, err := addExportedReportSymbols(r); err != nil {
+	if !*skipSymbols {
+		if _, err := checkReportSymbols(r); err != nil {
 			return err
 		}
 	}
@@ -262,7 +264,7 @@
 	return r.Write(filename)
 }
 
-func addExportedReportSymbols(r *report.Report) (bool, error) {
+func checkReportSymbols(r *report.Report) (bool, error) {
 	if len(r.OS) > 0 || len(r.Arch) > 0 {
 		return false, errors.New("specific GOOS/GOARCH not yet implemented")
 	}
@@ -272,7 +274,7 @@
 		if len(p.Symbols) == 0 {
 			continue
 		}
-		syms, err := findExportedSymbols(p.Module, p.Package, rc)
+		syms, err := findExportedSymbols(p, rc)
 		if err != nil {
 			return false, err
 		}
@@ -285,12 +287,47 @@
 	return added, nil
 }
 
-func findExportedSymbols(module, pkgPath string, c *reportClient) (_ []string, err error) {
-	defer derrors.Wrap(&err, "addExportedSymbols(%q, %q)", module, pkgPath)
+func findExportedSymbols(p report.Package, c *reportClient) (_ []string, err error) {
+	defer derrors.Wrap(&err, "addExportedSymbols(%q, %q)", p.Module, p.Package)
 
+	if p.VulnerableAt == "" {
+		fmt.Fprintf(os.Stderr, "%v: no vulnerable_at version, skipping symbol checks.\n", p.Package)
+		return nil, nil
+	}
+
+	module := p.Module
+	pkgPath := p.Package
 	if pkgPath == "" {
 		pkgPath = module
 	}
+
+	cleanup, err := changeToTempDir()
+	if err != nil {
+		return nil, err
+	}
+	defer cleanup()
+	if err := run("go", "mod", "init", "go.dev/_"); err != nil {
+		return nil, err
+	}
+	std := false
+	if !stdlib.Contains(p.Module) {
+		pkgPathAndVersion := pkgPath + "@" + p.VulnerableAt.V()
+		if err := run("go", "get", pkgPathAndVersion); err != nil {
+			return nil, err
+		}
+	} else {
+		std = true
+		gover := runtime.Version()
+		ver := semverForGoVersion(gover)
+		if ver == "" || !affected(c.entry, ver.V()) {
+			fmt.Fprintf(os.Stderr, "%v: Go version %q is not in a vulnerable range, skipping symbol checks.\n", pkgPath, gover)
+			return nil, nil
+		}
+		if ver != p.VulnerableAt {
+			fmt.Fprintf(os.Stderr, "%v: WARNING: Go version %q does not match vulnerable_at version %q.\n", pkgPath, ver, p.VulnerableAt)
+		}
+	}
+
 	pkgs, err := loadPackage(&packages.Config{}, pkgPath)
 	if err != nil {
 		return nil, err
@@ -302,8 +339,14 @@
 	if pkgs[0].PkgPath != pkgPath {
 		return nil, fmt.Errorf("first package had import path %s, wanted %s", pkgs[0].PkgPath, pkgPath)
 	}
-	if pm := pkgs[0].Module; pm == nil || pm.Path != module {
-		return nil, fmt.Errorf("got module %v, expected %s", pm, module)
+	if std {
+		if pm := pkgs[0].Module; std && pm != nil {
+			return nil, fmt.Errorf("got module %v, expected nil", pm)
+		}
+	} else {
+		if pm := pkgs[0].Module; pm == nil || pm.Path != module {
+			return nil, fmt.Errorf("got module %v, expected %s", pm, module)
+		}
 	}
 	newsyms, err := exportedFunctions(pkgs, c)
 	if err != nil {
@@ -418,6 +461,35 @@
 	return pkgs, nil
 }
 
+func changeToTempDir() (cleanup func(), _ error) {
+	cwd, err := os.Getwd()
+	if err != nil {
+		return nil, err
+	}
+	dir, err := os.MkdirTemp("", "vulnreport")
+	if err != nil {
+		return nil, err
+	}
+	cleanup = func() {
+		os.Chdir(cwd)
+		os.RemoveAll(dir)
+	}
+	if err := os.Chdir(dir); err != nil {
+		cleanup()
+		return nil, err
+	}
+	return cleanup, err
+}
+
+func run(name string, arg ...string) error {
+	cmd := exec.Command(name, arg...)
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		os.Stderr.Write(out)
+	}
+	return err
+}
+
 // setDates sets the PublishedDate of the report at filename to the oldest
 // commit date in the repo that contains that file. (It may someday also set a
 // last-modified date, hence the plural.) Since it looks at the commits from
diff --git a/internal/report/lint.go b/internal/report/lint.go
index 17a64c4..21769cf 100644
--- a/internal/report/lint.go
+++ b/internal/report/lint.go
@@ -189,6 +189,9 @@
 }
 
 func (p *Package) lintVersions(addPkgIssue func(string)) {
+	if p.VulnerableAt != "" && !p.VulnerableAt.IsValid() {
+		addPkgIssue(fmt.Sprintf("invalid vulnerable_at semantic version: %q", p.VulnerableAt))
+	}
 	for i, vr := range p.Versions {
 		for _, v := range []Version{vr.Introduced, vr.Fixed} {
 			if v != "" && !v.IsValid() {
diff --git a/internal/report/report.go b/internal/report/report.go
index 9fcb211..5748edb 100644
--- a/internal/report/report.go
+++ b/internal/report/report.go
@@ -56,6 +56,13 @@
 	// or other technique.
 	DerivedSymbols []string       `yaml:"derived_symbols,omitempty"`
 	Versions       []VersionRange `yaml:",omitempty"`
+	// Known-vulnerable version, to use when performing static analysis or
+	// other techniques on a vulnerable version of the package.
+	//
+	// In general, we want to use the most recent vulnerable version of
+	// the package. Determining this programmatically is difficult, especially
+	// for packages without tagged versions, so we specify it manually here.
+	VulnerableAt Version `yaml:"vulnerable_at,omitempty"`
 }
 
 type Links struct {