cmd/vulnreport, internal/ghsa: speed up GHSA query

The "vulnreport fix" command queries GitHub for GHSAs related to the CVE
in each fixed report. Instead of querying for every Go GHSA (slow),
query for GHSAs related to the specific CVE(s) of interest.

Skip the GHSA query entirely if the GHSAs field is populated,
unless the -always-fix-ghsa flag is provided.

This may issue more queries if fixing many reports with no GHSAs,
but dramatically speeds up the query when fixing a single report
(from 30s to <1s). Also, the number of queries no longer scales with
the size of the GitHub GHSA corpus.

Change-Id: I0dde2ad7ebb4621785575c0dffdadd3febd873d4
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/404116
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
diff --git a/cmd/vulnreport/main.go b/cmd/vulnreport/main.go
index a594e95..daee923 100644
--- a/cmd/vulnreport/main.go
+++ b/cmd/vulnreport/main.go
@@ -35,6 +35,7 @@
 	issueRepo           = flag.String("issue-repo", "github.com/golang/vulndb", "repo to create issues in")
 	githubToken         = flag.String("ghtoken", os.Getenv("VULN_GITHUB_ACCESS_TOKEN"), "GitHub access token")
 	skipExportedSymbols = flag.Bool("skip-exported", false, "for fix, don't look for exported symbols")
+	alwaysFixGHSA       = flag.Bool("always-fix-ghsa", false, "for fix, always update GHSAs")
 )
 
 func main() {
@@ -86,19 +87,7 @@
 			log.Fatal(err)
 		}
 	case "fix":
-		var GHSAsByCVE map[string][]string
-		if *githubToken == "" {
-			fmt.Println("flag -ghtoken not provided, so not fixing GHSAs")
-		} else {
-			fmt.Println("querying GitHub for GHSAs...")
-			var err error
-			GHSAsByCVE, err = loadGHSAsByCVE(ctx, *githubToken)
-			if err != nil {
-				log.Fatal(err)
-			}
-			fmt.Println("fixing...")
-		}
-		f := func(name string) error { return fix(name, GHSAsByCVE) }
+		f := func(name string) error { return fix(ctx, name, *githubToken) }
 		if err := multi(f, names); err != nil {
 			log.Fatal(err)
 		}
@@ -225,7 +214,7 @@
 	return nil
 }
 
-func fix(filename string, GHSAsByCVE map[string][]string) (err error) {
+func fix(ctx context.Context, filename string, accessToken string) (err error) {
 	defer derrors.Wrap(&err, "fix(%q)", filename)
 	r, err := report.Read(filename)
 	if err != nil {
@@ -239,8 +228,8 @@
 			return err
 		}
 	}
-	if GHSAsByCVE != nil {
-		fixGHSAs(r, GHSAsByCVE)
+	if err := fixGHSAs(ctx, r, accessToken); err != nil {
+		return err
 	}
 
 	// Write unconditionally in order to format.
@@ -407,11 +396,28 @@
 
 // fixGHSAs replaces r.GHSAs with a sorted list of GitHub Security
 // Advisory IDs that correspond to the CVEs.
-func fixGHSAs(r *report.Report, GHSAsByCVE map[string][]string) {
-	var gids []string
+func fixGHSAs(ctx context.Context, r *report.Report, accessToken string) error {
+	if accessToken == "" {
+		return nil
+	}
+	if len(r.GHSAs) > 0 && !*alwaysFixGHSA {
+		return nil
+	}
+	m := map[string]struct{}{}
 	for _, cid := range r.CVEs {
-		gids = append(gids, GHSAsByCVE[cid]...)
+		sas, err := ghsa.ListForCVE(ctx, accessToken, cid)
+		if err != nil {
+			return err
+		}
+		for _, sa := range sas {
+			m[sa.PrettyID()] = struct{}{}
+		}
+	}
+	var gids []string
+	for gid := range m {
+		gids = append(gids, gid)
 	}
 	sort.Strings(gids)
 	r.GHSAs = gids
+	return nil
 }
diff --git a/internal/ghsa/ghsa.go b/internal/ghsa/ghsa.go
index 47fdf5d..7fb38df 100644
--- a/internal/ghsa/ghsa.go
+++ b/internal/ghsa/ghsa.go
@@ -191,6 +191,46 @@
 	return sas, nil
 }
 
+func ListForCVE(ctx context.Context, accessToken string, cve string) ([]*SecurityAdvisory, error) {
+	client := newGitHubClient(ctx, accessToken)
+
+	var query struct { // The GraphQL query
+		SAs struct {
+			Nodes    []gqlSecurityAdvisory
+			PageInfo struct {
+				EndCursor   githubv4.String
+				HasNextPage bool
+			}
+		} `graphql:"securityAdvisories(identifier: $id, first: 100)"`
+	}
+	vars := map[string]any{
+		"id": githubv4.SecurityAdvisoryIdentifierFilter{
+			Type:  githubv4.SecurityAdvisoryIdentifierTypeCve,
+			Value: githubv4.String(cve),
+		},
+		"go": githubv4.SecurityAdvisoryEcosystemGo,
+	}
+
+	if err := client.Query(ctx, &query, vars); err != nil {
+		return nil, err
+	}
+	if query.SAs.PageInfo.HasNextPage {
+		return nil, fmt.Errorf("CVE %s has more than 100 GHSAs", cve)
+	}
+	var sas []*SecurityAdvisory
+	for _, sa := range query.SAs.Nodes {
+		if len(sa.Vulnerabilities.Nodes) == 0 {
+			continue
+		}
+		s, err := sa.securityAdvisory()
+		if err != nil {
+			return nil, err
+		}
+		sas = append(sas, s)
+	}
+	return sas, nil
+}
+
 // FetchGHSA returns the SecurityAdvisory for the given Github Security
 // Advisory ID.
 func FetchGHSA(ctx context.Context, accessToken, ghsaID string) (_ *SecurityAdvisory, err error) {
diff --git a/internal/ghsa/ghsa_test.go b/internal/ghsa/ghsa_test.go
index 57af6fa..9652f66 100644
--- a/internal/ghsa/ghsa_test.go
+++ b/internal/ghsa/ghsa_test.go
@@ -67,3 +67,29 @@
 		t.Errorf("got GHSA with id %q, want %q", got.ID, want)
 	}
 }
+
+func TestListForCVE(t *testing.T) {
+	accessToken := mustGetAccessToken(t)
+	// Real CVE and GHSA.
+	const (
+		cveID  string = "CVE-2022-27191"
+		ghsaID string = "GHSA-8c26-wmh5-6g9v"
+	)
+	got, err := ListForCVE(context.Background(), accessToken, cveID)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var ids []string
+	for _, sa := range got {
+		for _, id := range sa.Identifiers {
+			if id.Type != "GHSA" {
+				continue
+			}
+			ids = append(ids, id.Value)
+			if id.Value == ghsaID {
+				return
+			}
+		}
+	}
+	t.Errorf("got %v GHSAs %v, want %v", len(got), ids, ghsaID)
+}