internal/worker: update changed CVEs
Add logic to update the DB when a CVE in the repo changes.
Change-Id: Ic03ddaead4e016d45f3cc8c07fe3c1b4d61f7936
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/366874
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/worker/update.go b/internal/worker/update.go
index 63fd8f0..67832b7 100644
--- a/internal/worker/update.go
+++ b/internal/worker/update.go
@@ -100,15 +100,6 @@
return st.SetUpdateRecord(ctx, ur)
}
-// Action performed by handleCVE.
-type action int
-
-const (
- nothing action = iota
- add
- mod
-)
-
func updateBatch(ctx context.Context, batch []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, needsIssue triageFunc) (numAdds, numMods int, err error) {
startID := idFromFilename(batch[0].filename)
endID := idFromFilename(batch[len(batch)-1].filename)
@@ -131,14 +122,18 @@
// Determine what needs to be added and modified.
for _, f := range batch {
id := idFromFilename(f.filename)
- act, err := handleCVE(ctx, repo, f, idToRecord[id], commitHash, needsIssue, tx)
+ old := idToRecord[id]
+ if old != nil && old.BlobHash == f.hash.String() {
+ // No change; do nothing.
+ continue
+ }
+ added, err := handleCVE(ctx, repo, f, old, commitHash, needsIssue, tx)
if err != nil {
return err
}
- switch act {
- case add:
+ if added {
numAdds++
- case mod:
+ } else {
numMods++
}
}
@@ -152,48 +147,77 @@
}
// handleCVE determines how to change the store for a single CVE.
-func handleCVE(ctx context.Context, repo *git.Repository, f repoFile, old *store.CVERecord, commitHash plumbing.Hash, needsIssue triageFunc, tx store.Transaction) (_ action, err error) {
+// The CVE will definitely be either added, if it's new, or modified, if it's
+// already in the DB.
+func handleCVE(ctx context.Context, repo *git.Repository, f repoFile, old *store.CVERecord, commitHash plumbing.Hash, needsIssue triageFunc, tx store.Transaction) (added bool, err error) {
defer derrors.Wrap(&err, "handleCVE(%s)", f.filename)
- if old != nil && old.BlobHash == f.hash.String() {
- // No change; do nothing.
- return nothing, nil
- }
// Read CVE from repo.
r, err := blobReader(repo, f.hash)
if err != nil {
- return nothing, err
+ return false, err
}
+ pathname := path.Join(f.dirpath, f.filename)
cve := &cveschema.CVE{}
if err := json.NewDecoder(r).Decode(cve); err != nil {
- log.Printf("ERROR decoding %s: %v", f.filename, err)
- return nothing, nil
+ return false, err
+ }
+ needs := false
+ if cve.State == cveschema.StatePublic {
+ needs, err = needsIssue(cve)
+ if err != nil {
+ return false, err
+ }
}
// If the CVE is not in the database, add it.
if old == nil {
cr := store.NewCVERecord(cve, path.Join(f.dirpath, f.filename), f.hash.String())
cr.CommitHash = commitHash.String()
- needs := false
- if cve.State == cveschema.StatePublic {
- needs, err = needsIssue(cve)
- if err != nil {
- return nothing, err
- }
- }
if needs {
cr.TriageState = store.TriageStateNeedsIssue
} else {
cr.TriageState = store.TriageStateNoActionNeeded
}
if err := tx.CreateCVERecord(cr); err != nil {
- return nothing, err
+ return false, err
}
- return add, nil
- } else {
- // TODO(golang/go#49733): handle changes to CVEs.
+ return true, nil
}
- return nothing, nil
+ // Change to an existing record.
+ mod := *old // copy the old one
+ mod.Path = pathname
+ mod.BlobHash = f.hash.String()
+ mod.CVEState = cve.State
+ mod.CommitHash = commitHash.String()
+ switch old.TriageState {
+ case store.TriageStateNoActionNeeded:
+ if needs {
+ // Didn't need an issue before, does now.
+ mod.TriageState = store.TriageStateNeedsIssue
+ }
+ // Else don't change the triage state, but we still want
+ // to update the other changed fields.
+ case store.TriageStateNeedsIssue:
+ if !needs {
+ // Needed an issue, no longer does.
+ mod.TriageState = store.TriageStateNoActionNeeded
+ }
+ // Else don't change the triage state, but we still want
+ // to update the other changed fields.
+ case store.TriageStateIssueCreated, store.TriageStateUpdatedSinceIssueCreation:
+ // An issue was filed, so a person should revisit this CVE.
+ mod.TriageState = store.TriageStateUpdatedSinceIssueCreation
+ mod.TriageStateReason = fmt.Sprintf("CVE changed; needs issue = %t", needs)
+ // TODO(golang/go#49733): keep a history of the previous states and their commits.
+ default:
+ return false, fmt.Errorf("unknown TriageState: %q", old.TriageState)
+ }
+ // If we're here, then mod is a modification to the DB.
+ if err := tx.SetCVERecord(&mod); err != nil {
+ return false, err
+ }
+ return false, nil
}
type repoFile struct {
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index f68f91c..283c8cf 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -12,7 +12,6 @@
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
- "github.com/go-git/go-git/v5/plumbing/object"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/vuln/internal/cveschema"
@@ -32,6 +31,7 @@
if err != nil {
t.Fatal(err)
}
+
want := []repoFile{
{dirpath: "2021/0xxx", filename: "CVE-2021-0001.json"},
{dirpath: "2021/0xxx", filename: "CVE-2021-0010.json"},
@@ -49,7 +49,6 @@
if err != nil {
t.Fatal(err)
}
- mstore := store.NewMemStore()
h, err := headHash(repo)
if err != nil {
t.Fatal(err)
@@ -57,39 +56,153 @@
needsIssue := func(cve *cveschema.CVE) (bool, error) {
return strings.HasSuffix(cve.ID, "0001"), nil
}
- if err := doUpdate(ctx, repo, h, mstore, needsIssue); err != nil {
- t.Fatal(err)
- }
+
ref, err := repo.Reference(plumbing.HEAD, true)
if err != nil {
t.Fatal(err)
}
- r1 := newTestCVERecord(t, repo, ref, "2021/0xxx/CVE-2021-0001.json", store.TriageStateNeedsIssue)
- r10 := newTestCVERecord(t, repo, ref, "2021/0xxx/CVE-2021-0010.json", store.TriageStateNoActionNeeded)
- r384 := newTestCVERecord(t, repo, ref, "2021/1xxx/CVE-2021-1384.json", store.TriageStateNoActionNeeded)
- wantRecords := map[string]*store.CVERecord{
- "CVE-2021-0001": r1,
- "CVE-2021-0010": r10,
- "CVE-2021-1384": r384,
+ commitHash := ref.Hash().String()
+ const (
+ path1 = "2021/0xxx/CVE-2021-0001.json"
+ path2 = "2021/0xxx/CVE-2021-0010.json"
+ path3 = "2021/1xxx/CVE-2021-1384.json"
+ )
+ cve1, bh1 := readCVE(t, repo, path1)
+ cve2, bh2 := readCVE(t, repo, path2)
+ cve3, bh3 := readCVE(t, repo, path3)
+
+ // CVERecords after the above CVEs are added to an empty DB.
+ rs := []*store.CVERecord{
+ {
+ ID: cve1.ID,
+ CVEState: cve1.State,
+ Path: path1,
+ BlobHash: bh1,
+ CommitHash: commitHash,
+ TriageState: store.TriageStateNeedsIssue, // a public CVE, needsIssue returns true
+ },
+ {
+ ID: cve2.ID,
+ CVEState: cve2.State,
+ Path: path2,
+ BlobHash: bh2,
+ CommitHash: commitHash,
+ TriageState: store.TriageStateNoActionNeeded, // state is reserved
+ },
+ {
+ ID: cve3.ID,
+ CVEState: cve3.State,
+ Path: path3,
+ BlobHash: bh3,
+ CommitHash: commitHash,
+ TriageState: store.TriageStateNoActionNeeded, // state is rejected
+ },
}
- diff := cmp.Diff(wantRecords, mstore.CVERecords())
- if diff != "" {
- t.Errorf("mismatch (-want, +got):\n%s", diff)
+
+ // withTriageState returns a copy of r with the TriageState field changed to ts.
+ withTriageState := func(r *store.CVERecord, ts store.TriageState) *store.CVERecord {
+ c := *r
+ c.BlobHash += "x" // if we don't use a different blob hash, no update will happen
+ c.CommitHash = "?"
+ c.TriageState = ts
+ return &c
+ }
+
+ for _, test := range []struct {
+ name string
+ cur []*store.CVERecord // current state of DB
+ want []*store.CVERecord // expected state after update
+ }{
+ {
+ name: "empty",
+ cur: nil,
+ want: rs,
+ },
+ {
+ name: "no change",
+ cur: rs,
+ want: rs,
+ },
+ {
+ name: "pre-issue changes",
+ cur: []*store.CVERecord{
+ // NoActionNeeded -> NeedsIssue
+ withTriageState(rs[0], store.TriageStateNoActionNeeded),
+ // NeedsIssue -> NoActionNeeded
+ withTriageState(rs[1], store.TriageStateNeedsIssue),
+ // NoActionNeeded, triage state stays the same but other fields change.
+ withTriageState(rs[2], store.TriageStateNoActionNeeded),
+ },
+ want: rs,
+ },
+ {
+ name: "post-issue changes",
+ cur: []*store.CVERecord{
+ // IssueCreated -> Updated
+ withTriageState(rs[0], store.TriageStateIssueCreated),
+ withTriageState(rs[1], store.TriageStateUpdatedSinceIssueCreation),
+ },
+ want: []*store.CVERecord{
+ func() *store.CVERecord {
+ c := *rs[0]
+ c.TriageState = store.TriageStateUpdatedSinceIssueCreation
+ c.TriageStateReason = "CVE changed; needs issue = true"
+ return &c
+ }(),
+ func() *store.CVERecord {
+ c := *rs[1]
+ c.TriageState = store.TriageStateUpdatedSinceIssueCreation
+ c.TriageStateReason = "CVE changed; needs issue = false"
+ return &c
+ }(),
+ rs[2],
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ mstore := store.NewMemStore()
+ createCVERecords(t, mstore, test.cur)
+ if err := doUpdate(ctx, repo, h, mstore, needsIssue); err != nil {
+ t.Fatal(err)
+ }
+ got := mstore.CVERecords()
+ want := map[string]*store.CVERecord{}
+ for _, cr := range test.want {
+ want[cr.ID] = cr
+ }
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("mismatch (-want, +got):\n%s", diff)
+ }
+ })
}
}
-func newTestCVERecord(t *testing.T, repo *git.Repository, ref *plumbing.Reference, path string, ts store.TriageState) *store.CVERecord {
+func readCVE(t *testing.T, repo *git.Repository, path string) (*cveschema.CVE, string) {
blob := findBlob(t, repo, path)
- r := store.NewCVERecord(readCVE(t, blob), path, blob.Hash.String())
+ var cve cveschema.CVE
+ if err := json.Unmarshal(readBlob(t, blob), &cve); err != nil {
+ t.Fatal(err)
+ }
+ return &cve, blob.Hash.String()
+}
+
+func newTestCVERecord(cve *cveschema.CVE, path, blobHash string, ref *plumbing.Reference, ts store.TriageState) *store.CVERecord {
+ r := store.NewCVERecord(cve, path, blobHash)
r.CommitHash = ref.Hash().String()
r.TriageState = ts
return r
}
-func readCVE(t *testing.T, blob *object.Blob) *cveschema.CVE {
- var cve cveschema.CVE
- if err := json.Unmarshal(readBlob(t, blob), &cve); err != nil {
+func createCVERecords(t *testing.T, s store.Store, crs []*store.CVERecord) {
+ err := s.RunTransaction(context.Background(), func(_ context.Context, tx store.Transaction) error {
+ for _, cr := range crs {
+ if err := tx.CreateCVERecord(cr); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ if err != nil {
t.Fatal(err)
}
- return &cve
}