internal/worker: update changed CVEs

Add logic to update the DB when a CVE in the repo changes.

Change-Id: Ic03ddaead4e016d45f3cc8c07fe3c1b4d61f7936
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/366874
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/worker/update.go b/internal/worker/update.go
index 63fd8f0..67832b7 100644
--- a/internal/worker/update.go
+++ b/internal/worker/update.go
@@ -100,15 +100,6 @@
 	return st.SetUpdateRecord(ctx, ur)
 }
 
-// Action performed by handleCVE.
-type action int
-
-const (
-	nothing action = iota
-	add
-	mod
-)
-
 func updateBatch(ctx context.Context, batch []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, needsIssue triageFunc) (numAdds, numMods int, err error) {
 	startID := idFromFilename(batch[0].filename)
 	endID := idFromFilename(batch[len(batch)-1].filename)
@@ -131,14 +122,18 @@
 		// Determine what needs to be added and modified.
 		for _, f := range batch {
 			id := idFromFilename(f.filename)
-			act, err := handleCVE(ctx, repo, f, idToRecord[id], commitHash, needsIssue, tx)
+			old := idToRecord[id]
+			if old != nil && old.BlobHash == f.hash.String() {
+				// No change; do nothing.
+				continue
+			}
+			added, err := handleCVE(ctx, repo, f, old, commitHash, needsIssue, tx)
 			if err != nil {
 				return err
 			}
-			switch act {
-			case add:
+			if added {
 				numAdds++
-			case mod:
+			} else {
 				numMods++
 			}
 		}
@@ -152,48 +147,77 @@
 }
 
 // handleCVE determines how to change the store for a single CVE.
-func handleCVE(ctx context.Context, repo *git.Repository, f repoFile, old *store.CVERecord, commitHash plumbing.Hash, needsIssue triageFunc, tx store.Transaction) (_ action, err error) {
+// The CVE will definitely be either added, if it's new, or modified, if it's
+// already in the DB.
+func handleCVE(ctx context.Context, repo *git.Repository, f repoFile, old *store.CVERecord, commitHash plumbing.Hash, needsIssue triageFunc, tx store.Transaction) (added bool, err error) {
 	defer derrors.Wrap(&err, "handleCVE(%s)", f.filename)
 
-	if old != nil && old.BlobHash == f.hash.String() {
-		// No change; do nothing.
-		return nothing, nil
-	}
 	// Read CVE from repo.
 	r, err := blobReader(repo, f.hash)
 	if err != nil {
-		return nothing, err
+		return false, err
 	}
+	pathname := path.Join(f.dirpath, f.filename)
 	cve := &cveschema.CVE{}
 	if err := json.NewDecoder(r).Decode(cve); err != nil {
-		log.Printf("ERROR decoding %s: %v", f.filename, err)
-		return nothing, nil
+		return false, err
+	}
+	needs := false
+	if cve.State == cveschema.StatePublic {
+		needs, err = needsIssue(cve)
+		if err != nil {
+			return false, err
+		}
 	}
 
 	// If the CVE is not in the database, add it.
 	if old == nil {
 		cr := store.NewCVERecord(cve, path.Join(f.dirpath, f.filename), f.hash.String())
 		cr.CommitHash = commitHash.String()
-		needs := false
-		if cve.State == cveschema.StatePublic {
-			needs, err = needsIssue(cve)
-			if err != nil {
-				return nothing, err
-			}
-		}
 		if needs {
 			cr.TriageState = store.TriageStateNeedsIssue
 		} else {
 			cr.TriageState = store.TriageStateNoActionNeeded
 		}
 		if err := tx.CreateCVERecord(cr); err != nil {
-			return nothing, err
+			return false, err
 		}
-		return add, nil
-	} else {
-		// TODO(golang/go#49733): handle changes to CVEs.
+		return true, nil
 	}
-	return nothing, nil
+	// Change to an existing record.
+	mod := *old // copy the old one
+	mod.Path = pathname
+	mod.BlobHash = f.hash.String()
+	mod.CVEState = cve.State
+	mod.CommitHash = commitHash.String()
+	switch old.TriageState {
+	case store.TriageStateNoActionNeeded:
+		if needs {
+			// Didn't need an issue before, does now.
+			mod.TriageState = store.TriageStateNeedsIssue
+		}
+		// Else don't change the triage state, but we still want
+		// to update the other changed fields.
+	case store.TriageStateNeedsIssue:
+		if !needs {
+			// Needed an issue, no longer does.
+			mod.TriageState = store.TriageStateNoActionNeeded
+		}
+		// Else don't change the triage state, but we still want
+		// to update the other changed fields.
+	case store.TriageStateIssueCreated, store.TriageStateUpdatedSinceIssueCreation:
+		// An issue was filed, so a person should revisit this CVE.
+		mod.TriageState = store.TriageStateUpdatedSinceIssueCreation
+		mod.TriageStateReason = fmt.Sprintf("CVE changed; needs issue = %t", needs)
+		// TODO(golang/go#49733): keep a history of the previous states and their commits.
+	default:
+		return false, fmt.Errorf("unknown TriageState: %q", old.TriageState)
+	}
+	// If we're here, then mod is a modification to the DB.
+	if err := tx.SetCVERecord(&mod); err != nil {
+		return false, err
+	}
+	return false, nil
 }
 
 type repoFile struct {
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index f68f91c..283c8cf 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -12,7 +12,6 @@
 
 	"github.com/go-git/go-git/v5"
 	"github.com/go-git/go-git/v5/plumbing"
-	"github.com/go-git/go-git/v5/plumbing/object"
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"golang.org/x/vuln/internal/cveschema"
@@ -32,6 +31,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
+
 	want := []repoFile{
 		{dirpath: "2021/0xxx", filename: "CVE-2021-0001.json"},
 		{dirpath: "2021/0xxx", filename: "CVE-2021-0010.json"},
@@ -49,7 +49,6 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	mstore := store.NewMemStore()
 	h, err := headHash(repo)
 	if err != nil {
 		t.Fatal(err)
@@ -57,39 +56,153 @@
 	needsIssue := func(cve *cveschema.CVE) (bool, error) {
 		return strings.HasSuffix(cve.ID, "0001"), nil
 	}
-	if err := doUpdate(ctx, repo, h, mstore, needsIssue); err != nil {
-		t.Fatal(err)
-	}
+
 	ref, err := repo.Reference(plumbing.HEAD, true)
 	if err != nil {
 		t.Fatal(err)
 	}
-	r1 := newTestCVERecord(t, repo, ref, "2021/0xxx/CVE-2021-0001.json", store.TriageStateNeedsIssue)
-	r10 := newTestCVERecord(t, repo, ref, "2021/0xxx/CVE-2021-0010.json", store.TriageStateNoActionNeeded)
-	r384 := newTestCVERecord(t, repo, ref, "2021/1xxx/CVE-2021-1384.json", store.TriageStateNoActionNeeded)
-	wantRecords := map[string]*store.CVERecord{
-		"CVE-2021-0001": r1,
-		"CVE-2021-0010": r10,
-		"CVE-2021-1384": r384,
+	commitHash := ref.Hash().String()
+	const (
+		path1 = "2021/0xxx/CVE-2021-0001.json"
+		path2 = "2021/0xxx/CVE-2021-0010.json"
+		path3 = "2021/1xxx/CVE-2021-1384.json"
+	)
+	cve1, bh1 := readCVE(t, repo, path1)
+	cve2, bh2 := readCVE(t, repo, path2)
+	cve3, bh3 := readCVE(t, repo, path3)
+
+	// CVERecords after the above CVEs are added to an empty DB.
+	rs := []*store.CVERecord{
+		{
+			ID:          cve1.ID,
+			CVEState:    cve1.State,
+			Path:        path1,
+			BlobHash:    bh1,
+			CommitHash:  commitHash,
+			TriageState: store.TriageStateNeedsIssue, // a public CVE, needsIssue returns true
+		},
+		{
+			ID:          cve2.ID,
+			CVEState:    cve2.State,
+			Path:        path2,
+			BlobHash:    bh2,
+			CommitHash:  commitHash,
+			TriageState: store.TriageStateNoActionNeeded, // state is reserved
+		},
+		{
+			ID:          cve3.ID,
+			CVEState:    cve3.State,
+			Path:        path3,
+			BlobHash:    bh3,
+			CommitHash:  commitHash,
+			TriageState: store.TriageStateNoActionNeeded, // state is rejected
+		},
 	}
-	diff := cmp.Diff(wantRecords, mstore.CVERecords())
-	if diff != "" {
-		t.Errorf("mismatch (-want, +got):\n%s", diff)
+
+	// withTriageState returns a copy of r with the TriageState field changed to ts.
+	withTriageState := func(r *store.CVERecord, ts store.TriageState) *store.CVERecord {
+		c := *r
+		c.BlobHash += "x" // if we don't use a different blob hash, no update will happen
+		c.CommitHash = "?"
+		c.TriageState = ts
+		return &c
+	}
+
+	for _, test := range []struct {
+		name string
+		cur  []*store.CVERecord // current state of DB
+		want []*store.CVERecord // expected state after update
+	}{
+		{
+			name: "empty",
+			cur:  nil,
+			want: rs,
+		},
+		{
+			name: "no change",
+			cur:  rs,
+			want: rs,
+		},
+		{
+			name: "pre-issue changes",
+			cur: []*store.CVERecord{
+				// NoActionNeeded -> NeedsIssue
+				withTriageState(rs[0], store.TriageStateNoActionNeeded),
+				// NeedsIssue -> NoActionNeeded
+				withTriageState(rs[1], store.TriageStateNeedsIssue),
+				// NoActionNeeded, triage state stays the same but other fields change.
+				withTriageState(rs[2], store.TriageStateNoActionNeeded),
+			},
+			want: rs,
+		},
+		{
+			name: "post-issue changes",
+			cur: []*store.CVERecord{
+				// IssueCreated -> Updated
+				withTriageState(rs[0], store.TriageStateIssueCreated),
+				withTriageState(rs[1], store.TriageStateUpdatedSinceIssueCreation),
+			},
+			want: []*store.CVERecord{
+				func() *store.CVERecord {
+					c := *rs[0]
+					c.TriageState = store.TriageStateUpdatedSinceIssueCreation
+					c.TriageStateReason = "CVE changed; needs issue = true"
+					return &c
+				}(),
+				func() *store.CVERecord {
+					c := *rs[1]
+					c.TriageState = store.TriageStateUpdatedSinceIssueCreation
+					c.TriageStateReason = "CVE changed; needs issue = false"
+					return &c
+				}(),
+				rs[2],
+			},
+		},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			mstore := store.NewMemStore()
+			createCVERecords(t, mstore, test.cur)
+			if err := doUpdate(ctx, repo, h, mstore, needsIssue); err != nil {
+				t.Fatal(err)
+			}
+			got := mstore.CVERecords()
+			want := map[string]*store.CVERecord{}
+			for _, cr := range test.want {
+				want[cr.ID] = cr
+			}
+			if diff := cmp.Diff(want, got); diff != "" {
+				t.Errorf("mismatch (-want, +got):\n%s", diff)
+			}
+		})
 	}
 }
 
-func newTestCVERecord(t *testing.T, repo *git.Repository, ref *plumbing.Reference, path string, ts store.TriageState) *store.CVERecord {
+func readCVE(t *testing.T, repo *git.Repository, path string) (*cveschema.CVE, string) {
 	blob := findBlob(t, repo, path)
-	r := store.NewCVERecord(readCVE(t, blob), path, blob.Hash.String())
+	var cve cveschema.CVE
+	if err := json.Unmarshal(readBlob(t, blob), &cve); err != nil {
+		t.Fatal(err)
+	}
+	return &cve, blob.Hash.String()
+}
+
+func newTestCVERecord(cve *cveschema.CVE, path, blobHash string, ref *plumbing.Reference, ts store.TriageState) *store.CVERecord {
+	r := store.NewCVERecord(cve, path, blobHash)
 	r.CommitHash = ref.Hash().String()
 	r.TriageState = ts
 	return r
 }
 
-func readCVE(t *testing.T, blob *object.Blob) *cveschema.CVE {
-	var cve cveschema.CVE
-	if err := json.Unmarshal(readBlob(t, blob), &cve); err != nil {
+func createCVERecords(t *testing.T, s store.Store, crs []*store.CVERecord) {
+	err := s.RunTransaction(context.Background(), func(_ context.Context, tx store.Transaction) error {
+		for _, cr := range crs {
+			if err := tx.CreateCVERecord(cr); err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+	if err != nil {
 		t.Fatal(err)
 	}
-	return &cve
 }