internal/{cveutils,worker}: move CVE triage to cveutils

Move the existing logic for triaging v4 CVEs into the cveutils package.
This will make it easier to add tests and implement triage for v5 CVEs.

Change-Id: I4872af391a33500dd7236795a910ad3a6998b5e0
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/550857
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/cmd/worker/main.go b/cmd/worker/main.go
index adbf91b..db173aa 100644
--- a/cmd/worker/main.go
+++ b/cmd/worker/main.go
@@ -21,6 +21,7 @@
 	"time"
 
 	"golang.org/x/vulndb/internal/cvelistrepo"
+	"golang.org/x/vulndb/internal/cveutils"
 	"golang.org/x/vulndb/internal/ghsa"
 	"golang.org/x/vulndb/internal/gitrepo"
 	"golang.org/x/vulndb/internal/issues"
@@ -237,7 +238,7 @@
 	if err := scan.Err(); err != nil {
 		return err
 	}
-	worker.SetKnownModules(mods)
+	cveutils.SetKnownModules(mods)
 	fmt.Printf("set %d known modules\n", len(mods))
 	return nil
 }
diff --git a/internal/worker/paths.go b/internal/cveutils/paths.go
similarity index 99%
rename from internal/worker/paths.go
rename to internal/cveutils/paths.go
index 0840053..006de1e 100644
--- a/internal/worker/paths.go
+++ b/internal/cveutils/paths.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package worker
+package cveutils
 
 import (
 	"path"
diff --git a/internal/worker/paths_test.go b/internal/cveutils/paths_test.go
similarity index 98%
rename from internal/worker/paths_test.go
rename to internal/cveutils/paths_test.go
index ce027e8..0bfec14 100644
--- a/internal/worker/paths_test.go
+++ b/internal/cveutils/paths_test.go
@@ -5,7 +5,7 @@
 //go:build go1.17
 // +build go1.17
 
-package worker
+package cveutils
 
 import (
 	"testing"
diff --git a/internal/cveutils/pkgsite.go b/internal/cveutils/pkgsite.go
new file mode 100644
index 0000000..993d78f
--- /dev/null
+++ b/internal/cveutils/pkgsite.go
@@ -0,0 +1,121 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cveutils
+
+import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+
+	"golang.org/x/time/rate"
+	"golang.org/x/vulndb/internal/worker/log"
+)
+
+// Limit pkgsite requests to this many per second.
+const pkgsiteQPS = 5
+
+var (
+	// The limiter used to throttle pkgsite requests.
+	// The second argument to rate.NewLimiter is the burst, which
+	// basically lets you exceed the rate briefly.
+	pkgsiteRateLimiter = rate.NewLimiter(rate.Every(time.Duration(1000/float64(pkgsiteQPS))*time.Millisecond), 3)
+
+	// Cache of module paths already seen. The value records whether
+	// the path is known to be a module.
+	seenModulePath = map[string]bool{}
+	// Does seenModulePath contain all known modules?
+	cacheComplete = false
+)
+
+// SetKnownModules provides a list of all known modules,
+// so that no requests need to be made to pkg.go.dev.
+// Once it has been called, any module path missing from the cache
+// is reported as unknown without contacting pkgsite.
+func SetKnownModules(mods []string) {
+	for _, m := range mods {
+		seenModulePath[m] = true
+	}
+	cacheComplete = true
+}
+
+// pkgsiteURL is the base URL of the real pkg.go.dev.
+var pkgsiteURL = "https://pkg.go.dev"
+
+// knownToPkgsite reports whether pkgsite knows that modulePath actually refers
+// to a module.
+//
+// Answers come from the seenModulePath cache when possible. Otherwise, unless
+// SetKnownModules has marked the cache complete, it waits on the rate limiter
+// and sends a HEAD request for baseURL + "/mod/" + modulePath; a 200 status
+// means the module is known. The answer is cached for later calls.
+//
+// It returns an error if ctx is done before the request completes or if the
+// request itself fails.
+func knownToPkgsite(ctx context.Context, baseURL, modulePath string) (bool, error) {
+	// If we've seen it before, no need to call.
+	if b, ok := seenModulePath[modulePath]; ok {
+		return b, nil
+	}
+	// The cache lists every known module, so a miss means "unknown".
+	if cacheComplete {
+		return false, nil
+	}
+	// Pause to maintain a max QPS.
+	if err := pkgsiteRateLimiter.Wait(ctx); err != nil {
+		return false, err
+	}
+	start := time.Now()
+
+	url := baseURL + "/mod/" + modulePath
+	// Attach ctx to the request so that cancellation aborts the HTTP call.
+	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
+	if err != nil {
+		return false, err
+	}
+	res, err := http.DefaultClient.Do(req)
+	var status string
+	if err == nil {
+		// Always close the response body, even for HEAD, so the
+		// underlying connection is released.
+		defer res.Body.Close()
+		status = strconv.Quote(res.Status)
+	}
+	log.With(
+		"latency", time.Since(start),
+		"status", status,
+		"error", err,
+	).Debugf(ctx, "checked if %s is known to pkgsite at HEAD", url)
+	if err != nil {
+		return false, err
+	}
+	known := res.StatusCode == http.StatusOK
+	seenModulePath[modulePath] = known
+	return known, nil
+}
+
+// GetPkgsiteURL returns a URL to either a fake server or the real pkg.go.dev,
+// depending on the useRealPkgsite value.
+// The fake server is shut down when the test t finishes.
+//
+// For testing.
+func GetPkgsiteURL(t *testing.T, useRealPkgsite bool) string {
+	if useRealPkgsite {
+		return pkgsiteURL
+	}
+	// Start a test server that recognizes anything from golang.org and bitbucket.org/foo/bar/baz.
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		modulePath := strings.TrimPrefix(r.URL.Path, "/mod/")
+		if !strings.HasPrefix(modulePath, "golang.org/") &&
+			!strings.HasPrefix(modulePath, "bitbucket.org/foo/bar/baz") {
+			http.Error(w, "unknown", http.StatusNotFound)
+		}
+	}))
+	t.Cleanup(s.Close)
+	return s.URL
+}
diff --git a/internal/worker/triage.go b/internal/cveutils/triage.go
similarity index 61%
rename from internal/worker/triage.go
rename to internal/cveutils/triage.go
index 7309914..a47f782 100644
--- a/internal/worker/triage.go
+++ b/internal/cveutils/triage.go
@@ -2,20 +2,16 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package worker
+package cveutils
 
 import (
 	"context"
 	"errors"
 	"fmt"
-	"net/http"
 	"net/url"
 	"regexp"
-	"strconv"
 	"strings"
-	"time"
 
-	"golang.org/x/time/rate"
 	"golang.org/x/vulndb/internal/cveschema"
 	"golang.org/x/vulndb/internal/derrors"
 	"golang.org/x/vulndb/internal/ghsa"
@@ -40,7 +36,7 @@
 const unknownPath = "Path is unknown"
 
 // TriageCVE reports whether the CVE refers to a Go module.
-func TriageCVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (_ *triageResult, err error) {
+func TriageCVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (_ *TriageResult, err error) {
 	defer derrors.Wrap(&err, "triageCVE(%q)", c.ID)
 	switch c.DataVersion {
 	case "4.0":
@@ -51,10 +47,10 @@
 	}
 }
 
-type triageResult struct {
-	modulePath  string
-	packagePath string
-	reason      string
+type TriageResult struct {
+	ModulePath  string // module the CVE may affect; can be stdlib.ModulePath or unknownPath
+	PackagePath string // package path taken from a reference URL; empty when none was identified
+	Reason      string // human-readable explanation of why the CVE was flagged
 }
 
 // gopkgHosts are hostnames for popular Go package websites.
@@ -83,7 +79,7 @@
 }
 
 // triageV4CVE triages a CVE following schema v4.0 and returns the result.
-func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (result *triageResult, err error) {
+func triageV4CVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (result *TriageResult, err error) {
 	defer derrors.Wrap(&err, "triageV4CVE(ctx, %q, %q)", c.ID, pkgsiteURL)
 	defer func() {
 		if err != nil {
@@ -94,7 +90,7 @@
 			log.Debugf(ctx, "%s: not Go vuln", msg)
 			return
 		}
-		log.Debugf(ctx, "%s: is Go vuln:\n%s", msg, result.reason)
+		log.Debugf(ctx, "%s: is Go vuln:\n%s", msg, result.Reason)
 	}()
 	for _, r := range c.References.Data {
 		if r.URL == "" {
@@ -106,24 +102,24 @@
 		}
 		if strings.Contains(r.URL, "golang.org/pkg") {
 			mp := strings.TrimPrefix(refURL.Path, "/pkg/")
-			return &triageResult{
-				packagePath: mp,
-				modulePath:  stdlib.ModulePath,
-				reason:      fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp),
+			return &TriageResult{
+				PackagePath: mp,
+				ModulePath:  stdlib.ModulePath,
+				Reason:      fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp),
 			}, nil
 		}
 		if gopkgHosts[refURL.Host] {
 			mp := strings.TrimPrefix(refURL.Path, "/")
 			if stdlib.Contains(mp) {
-				return &triageResult{
-					packagePath: mp,
-					modulePath:  stdlib.ModulePath,
-					reason:      fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp),
+				return &TriageResult{
+					PackagePath: mp,
+					ModulePath:  stdlib.ModulePath,
+					Reason:      fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp),
 				}, nil
 			}
-			return &triageResult{
-				modulePath: mp,
-				reason:     fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp),
+			return &TriageResult{
+				ModulePath: mp,
+				Reason:     fmt.Sprintf("Reference data URL %q contains path %q", r.URL, mp),
 			}, nil
 		}
 		modpaths := candidateModulePaths(refURL.Host + refURL.Path)
@@ -137,9 +133,9 @@
 			}
 			if known {
 				u := pkgsiteURL + "/" + mp
-				return &triageResult{
-					modulePath: mp,
-					reason:     fmt.Sprintf("Reference data URL %q contains path %q; %q returned a status 200", r.URL, mp, u),
+				return &TriageResult{
+					ModulePath: mp,
+					Reason:     fmt.Sprintf("Reference data URL %q contains path %q; %q returned a status 200", r.URL, mp, u),
 				}, nil
 			}
 		}
@@ -151,9 +147,9 @@
 		// Example CVE containing snyk.io URL:
 		// https://github.com/CVEProject/cvelist/blob/899bba20d62eb73e04d1841a5ff04cd6225e1618/2020/7xxx/CVE-2020-7668.json#L52.
 		if strings.Contains(r.URL, snykIdentifier) {
-			return &triageResult{
-				modulePath: unknownPath,
-				reason:     fmt.Sprintf("Reference data URL %q contains %q", r.URL, snykIdentifier),
+			return &TriageResult{
+				ModulePath: unknownPath,
+				Reason:     fmt.Sprintf("Reference data URL %q contains %q", r.URL, snykIdentifier),
 			}, nil
 		}
 
@@ -161,9 +157,9 @@
 		// project.
 		for _, k := range stdlibReferenceDataKeywords {
 			if strings.Contains(r.URL, k) {
-				return &triageResult{
-					modulePath: stdlib.ModulePath,
-					reason:     fmt.Sprintf("Reference data URL %q contains %q", r.URL, k),
+				return &TriageResult{
+					ModulePath: stdlib.ModulePath,
+					Reason:     fmt.Sprintf("Reference data URL %q contains %q", r.URL, k),
 				}, nil
 			}
 		}
@@ -173,69 +169,12 @@
 
 var ghsaRegex = regexp.MustCompile(ghsa.Regex)
 
-func getAliasGHSAs(c *cveschema.CVE) []string {
+// GetAliasGHSAs returns the GHSA IDs (matches of ghsa.Regex) found in the
+// reference URLs of c, taking at most one ID from each URL.
+func GetAliasGHSAs(c *cveschema.CVE) []string {
 	var ghsas []string
 	for _, r := range c.References.Data {
 		ghsas = append(ghsas, ghsaRegex.FindAllString(r.URL, 1)...)
 	}
 	return ghsas
 }
-
-// Limit pkgsite requests to this many per second.
-const pkgsiteQPS = 5
-
-var (
-	// The limiter used to throttle pkgsite requests.
-	// The second argument to rate.NewLimiter is the burst, which
-	// basically lets you exceed the rate briefly.
-	pkgsiteRateLimiter = rate.NewLimiter(rate.Every(time.Duration(1000/float64(pkgsiteQPS))*time.Millisecond), 3)
-
-	// Cache of module paths already seen.
-	seenModulePath = map[string]bool{}
-	// Does seenModulePath contain all known modules?
-	cacheComplete = false
-)
-
-// SetKnownModules provides a list of all known modules,
-// so that no requests need to be made to pkg.go.dev.
-func SetKnownModules(mods []string) {
-	for _, m := range mods {
-		seenModulePath[m] = true
-	}
-	cacheComplete = true
-}
-
-// knownToPkgsite reports whether pkgsite knows that modulePath actually refers
-// to a module.
-func knownToPkgsite(ctx context.Context, baseURL, modulePath string) (bool, error) {
-	// If we've seen it before, no need to call.
-	if b, ok := seenModulePath[modulePath]; ok {
-		return b, nil
-	}
-	if cacheComplete {
-		return false, nil
-	}
-	// Pause to maintain a max QPS.
-	if err := pkgsiteRateLimiter.Wait(ctx); err != nil {
-		return false, err
-	}
-	start := time.Now()
-
-	url := baseURL + "/mod/" + modulePath
-	res, err := http.Head(url)
-	var status string
-	if err == nil {
-		status = strconv.Quote(res.Status)
-	}
-	log.With(
-		"latency", time.Since(start),
-		"status", status,
-		"error", err,
-	).Debugf(ctx, "checked if %s is known to pkgsite at HEAD", url)
-	if err != nil {
-		return false, err
-	}
-	known := res.StatusCode == http.StatusOK
-	seenModulePath[modulePath] = known
-	return known, nil
-}
diff --git a/internal/worker/triage_test.go b/internal/cveutils/triage_test.go
similarity index 73%
rename from internal/worker/triage_test.go
rename to internal/cveutils/triage_test.go
index 52e10cb..d8da5c5 100644
--- a/internal/worker/triage_test.go
+++ b/internal/cveutils/triage_test.go
@@ -5,14 +5,11 @@
 //go:build go1.17
 // +build go1.17
 
-package worker
+package cveutils
 
 import (
 	"context"
 	"flag"
-	"net/http"
-	"net/http/httptest"
-	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -25,12 +22,12 @@
 
 func TestTriageV4CVE(t *testing.T) {
 	ctx := context.Background()
-	url := getPkgsiteURL(t)
+	url := GetPkgsiteURL(t, *usePkgsite)
 
 	for _, test := range []struct {
 		name string
 		in   *cveschema.CVE
-		want *triageResult
+		want *TriageResult
 	}{
 		{
 			"repo path is unknown Go standard library",
@@ -41,8 +38,8 @@
 					},
 				},
 			},
-			&triageResult{
-				modulePath: stdlib.ModulePath,
+			&TriageResult{
+				ModulePath: stdlib.ModulePath,
 			},
 		},
 		{
@@ -54,9 +51,9 @@
 					},
 				},
 			},
-			&triageResult{
-				modulePath:  stdlib.ModulePath,
-				packagePath: "net/http",
+			&TriageResult{
+				ModulePath:  stdlib.ModulePath,
+				PackagePath: "net/http",
 			},
 		},
 		{
@@ -69,8 +66,8 @@
 					},
 				},
 			},
-			&triageResult{
-				modulePath: "golang.org/x/mod",
+			&TriageResult{
+				ModulePath: "golang.org/x/mod",
 			},
 		},
 		{
@@ -82,8 +79,8 @@
 					},
 				},
 			},
-			&triageResult{
-				modulePath: "golang.org/x/mod",
+			&TriageResult{
+				ModulePath: "golang.org/x/mod",
 			},
 		},
 		{
@@ -95,9 +92,9 @@
 					},
 				},
 			},
-			&triageResult{
-				modulePath:  stdlib.ModulePath,
-				packagePath: "net/http",
+			&TriageResult{
+				ModulePath:  stdlib.ModulePath,
+				PackagePath: "net/http",
 			},
 		},
 		{
@@ -120,8 +117,8 @@
 					},
 				},
 			},
-			&triageResult{
-				modulePath: "golang.org/x/exp/event",
+			&TriageResult{
+				ModulePath: "golang.org/x/exp/event",
 			},
 		},
 		{
@@ -144,8 +141,8 @@
 					},
 				},
 			},
-			&triageResult{
-				modulePath: unknownPath,
+			&TriageResult{
+				ModulePath: unknownPath,
 			},
 		},
 	} {
@@ -156,8 +153,8 @@
 				t.Fatal(err)
 			}
 			if diff := cmp.Diff(test.want, got,
-				cmp.AllowUnexported(triageResult{}),
-				cmpopts.IgnoreFields(triageResult{}, "reason")); diff != "" {
+				cmp.AllowUnexported(TriageResult{}),
+				cmpopts.IgnoreFields(TriageResult{}, "Reason")); diff != "" {
 				t.Errorf("mismatch (-want, +got):\n%s", diff)
 			}
 		})
@@ -168,7 +165,7 @@
 	ctx := context.Background()
 
 	const validModule = "golang.org/x/mod"
-	url := getPkgsiteURL(t)
+	url := GetPkgsiteURL(t, *usePkgsite)
 
 	for _, test := range []struct {
 		in   string
@@ -199,25 +196,7 @@
 		},
 	}
 	want := "GHSA-xxxx-yyyy-0000"
-	if got := getAliasGHSAs(cve); got[0] != want {
+	if got := GetAliasGHSAs(cve); got[0] != want {
 		t.Errorf("getAliasGHSAs: got %s, want %s", got, want)
 	}
 }
-
-// getPkgsiteURL returns a URL to either a fake server or the real pkg.go.dev,
-// depending on the usePkgsite flag.
-func getPkgsiteURL(t *testing.T) string {
-	if *usePkgsite {
-		return pkgsiteURL
-	}
-	// Start a test server that recognizes anything from golang.org and bitbucket.org/foo/bar/baz.
-	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		modulePath := strings.TrimPrefix(r.URL.Path, "/mod/")
-		if !strings.HasPrefix(modulePath, "golang.org/") &&
-			!strings.HasPrefix(modulePath, "bitbucket.org/foo/bar/baz") {
-			http.Error(w, "unknown", http.StatusNotFound)
-		}
-	}))
-	t.Cleanup(s.Close)
-	return s.URL
-}
diff --git a/internal/worker/update.go b/internal/worker/update.go
index 9cbe4d5..05467fe 100644
--- a/internal/worker/update.go
+++ b/internal/worker/update.go
@@ -16,6 +16,7 @@
 	"github.com/go-git/go-git/v5/plumbing/object"
 	"golang.org/x/vulndb/internal/cvelistrepo"
 	"golang.org/x/vulndb/internal/cveschema"
+	"golang.org/x/vulndb/internal/cveutils"
 	"golang.org/x/vulndb/internal/derrors"
 	"golang.org/x/vulndb/internal/ghsa"
 	"golang.org/x/vulndb/internal/observe"
@@ -26,7 +27,7 @@
 // A triageFunc triages a CVE: it decides whether an issue needs to be filed.
-// If so, it returns a non-empty string indicating the possibly
-// affected module.
-type triageFunc func(*cveschema.CVE) (*triageResult, error)
+// If so, it returns a non-nil result describing the possibly
+// affected module.
+type triageFunc func(*cveschema.CVE) (*cveutils.TriageResult, error)
 
 // A cveUpdater performs an update operation on the DB.
 type cveUpdater struct {
@@ -260,7 +261,7 @@
 // worker has already handled, and returns the appropriate triage state
 // based on this.
 func checkForAliases(cve *cveschema.CVE, tx store.Transaction) (store.TriageState, error) {
-	for _, ghsaID := range getAliasGHSAs(cve) {
+	for _, ghsaID := range cveutils.GetAliasGHSAs(cve) {
 		ghsa, err := tx.GetGHSARecord(ghsaID)
 		if err != nil {
 			return "", err
@@ -282,7 +283,7 @@
 	if err := cvelistrepo.Parse(u.repo, f, cve); err != nil {
 		return nil, false, err
 	}
-	var result *triageResult
+	var result *cveutils.TriageResult
 	if cve.State == cveschema.StatePublic && !u.knownIDs[cve.ID] {
 		c := cve
 		// If a false positive has changed, we only care about
@@ -309,9 +310,9 @@
 				return nil, false, err
 			}
 			cr.TriageState = triageState
-			cr.Module = result.modulePath
-			cr.Package = result.packagePath
-			cr.TriageStateReason = result.reason
+			cr.Module = result.ModulePath
+			cr.Package = result.PackagePath
+			cr.TriageStateReason = result.Reason
 			cr.CVE = cve
 		case u.knownIDs[cve.ID]:
 			cr.TriageState = store.TriageStateHasVuln
@@ -332,9 +333,9 @@
 		if result != nil {
 			// Didn't need an issue before, does now.
 			mod.TriageState = store.TriageStateNeedsIssue
-			mod.Module = result.modulePath
-			mod.Package = result.packagePath
-			mod.TriageStateReason = result.reason
+			mod.Module = result.ModulePath
+			mod.Package = result.PackagePath
+			mod.TriageStateReason = result.Reason
 			mod.CVE = cve
 		}
 		// Else don't change the triage state, but we still want
@@ -355,7 +356,7 @@
 		mod.TriageState = store.TriageStateUpdatedSinceIssueCreation
 		var mp string
 		if result != nil {
-			mp = result.modulePath
+			mp = result.ModulePath
 		}
 		mod.TriageStateReason = fmt.Sprintf("CVE changed; affected module = %q", mp)
 	case store.TriageStateAlias:
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index 9490ef9..63d0f4b 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -9,6 +9,7 @@
 
 import (
 	"context"
+	"flag"
 	"testing"
 	"time"
 
@@ -18,11 +19,14 @@
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"golang.org/x/vulndb/internal/cvelistrepo"
 	"golang.org/x/vulndb/internal/cveschema"
+	"golang.org/x/vulndb/internal/cveutils"
 	"golang.org/x/vulndb/internal/ghsa"
 	"golang.org/x/vulndb/internal/gitrepo"
 	"golang.org/x/vulndb/internal/worker/store"
 )
 
+var usePkgsite = flag.Bool("pkgsite", false, "use pkg.go.dev for tests")
+
 const clearString = "**CLEAR**"
 
 var clearCVE = &cveschema.CVE{}
@@ -90,9 +94,9 @@
 		t.Fatal(err)
 	}
 	commit := headCommit(t, repo)
-	purl := getPkgsiteURL(t)
-	needsIssue := func(cve *cveschema.CVE) (*triageResult, error) {
-		return TriageCVE(ctx, cve, purl)
+	purl := cveutils.GetPkgsiteURL(t, *usePkgsite)
+	needsIssue := func(cve *cveschema.CVE) (*cveutils.TriageResult, error) {
+		return cveutils.TriageCVE(ctx, cve, purl)
 	}
 
 	commitHash := commit.Hash.String()
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index c79d8dc..202d5d0 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -22,6 +22,7 @@
 	"golang.org/x/time/rate"
 	"golang.org/x/vulndb/internal/cvelistrepo"
 	"golang.org/x/vulndb/internal/cveschema"
+	"golang.org/x/vulndb/internal/cveutils"
 	"golang.org/x/vulndb/internal/derrors"
 	"golang.org/x/vulndb/internal/ghsa"
 	"golang.org/x/vulndb/internal/gitrepo"
@@ -74,8 +75,8 @@
 	if err != nil {
 		return err
 	}
-	u := newCVEUpdater(repo, commit, st, knownVulnIDs, func(cve *cveschema.CVE) (*triageResult, error) {
-		return TriageCVE(ctx, cve, pkgsiteURL)
+	u := newCVEUpdater(repo, commit, st, knownVulnIDs, func(cve *cveschema.CVE) (*cveutils.TriageResult, error) {
+		return cveutils.TriageCVE(ctx, cve, pkgsiteURL)
 	})
 	_, err = u.update(ctx)
 	return err