internal/worker: refactor doUpdate
Create a struct for the args instead of passing them down
to multiple functions.
Change-Id: Ifb5798cc2f0f963ac04a48313e8342499b98f51d
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/368858
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/worker/update.go b/internal/worker/update.go
index 07e8640..5b269f2 100644
--- a/internal/worker/update.go
+++ b/internal/worker/update.go
@@ -29,18 +29,42 @@
// A triageFunc triages a CVE: it decides whether an issue needs to be filed.
type triageFunc func(*cveschema.CVE) (bool, error)
-// doUpdate compares the repo at the given commit with the state
-// of the DB and updates the DB to match.
-//
+// An updater performs an update operation on the DB.
+type updater struct {
+ repo *git.Repository
+ commitHash plumbing.Hash
+ st store.Store
+ knownIDs map[string]bool
+ needsIssue triageFunc
+}
+
+// newUpdater creates an updater for updating the store with information from
+// the repo commit.
// needsIssue determines whether a CVE needs an issue to be filed for it.
-func doUpdate(ctx context.Context, repo *git.Repository, commitHash plumbing.Hash, st store.Store, knownVulnIDs []string, needsIssue triageFunc) (ur *store.CommitUpdateRecord, err error) {
+func newUpdater(repo *git.Repository, commitHash plumbing.Hash, st store.Store, knownVulnIDs []string, needsIssue triageFunc) *updater {
+ u := &updater{
+ repo: repo,
+ commitHash: commitHash,
+ st: st,
+ knownIDs: map[string]bool{},
+ needsIssue: needsIssue,
+ }
+ for _, k := range knownVulnIDs {
+ u.knownIDs[k] = true
+ }
+ return u
+}
+
+// update updates the DB to match the repo at the given commit.
+// It also triages new or changed issues.
+func (u *updater) update(ctx context.Context) (ur *store.CommitUpdateRecord, err error) {
// We want the action of reading the old DB record, updating it and
// writing it back to be atomic. It would be too expensive to do that one
// record at a time. Ideally we'd process the whole repo commit in one
// transaction, but Firestore has a limit on how many writes one
// transaction can do, so the CVE files in the repo are processed in
// batches, one transaction per batch.
- defer derrors.Wrap(&err, "doUpdate(%s)", commitHash)
+ defer derrors.Wrap(&err, "updater.update(%s)", u.commitHash)
defer func() {
if err != nil {
@@ -54,9 +78,9 @@
}
}()
- log.Info(ctx, "update starting", event.String("commit", commitHash.String()))
+ log.Info(ctx, "update starting", event.String("commit", u.commitHash.String()))
- commit, err := repo.CommitObject(commitHash)
+ commit, err := u.repo.CommitObject(u.commitHash)
if err != nil {
return nil, err
}
@@ -65,7 +89,7 @@
// It is cheaper to read all the files from the repo and compare
// them to the DB in bulk, than to walk the repo and process
// each file individually.
- files, err := repoCVEFiles(repo, commit)
+ files, err := repoCVEFiles(u.repo, commit)
if err != nil {
return nil, err
}
@@ -76,51 +100,45 @@
return nil, err
}
- // Put the known vuln CVE IDs into a map.
- known := map[string]bool{}
- for _, k := range knownVulnIDs {
- known[k] = true
- }
-
// Create a new CommitUpdateRecord to describe this run of doUpdate.
ur = &store.CommitUpdateRecord{
StartedAt: time.Now(),
- CommitHash: commitHash.String(),
+ CommitHash: u.commitHash.String(),
CommitTime: commit.Committer.When,
NumTotal: len(files),
}
- if err := st.CreateCommitUpdateRecord(ctx, ur); err != nil {
+ if err := u.st.CreateCommitUpdateRecord(ctx, ur); err != nil {
return ur, err
}
for _, dirFiles := range filesByDir {
- numProc, numAdds, numMods, err := updateDirectory(ctx, dirFiles, st, repo, commitHash, known, needsIssue)
+ numProc, numAdds, numMods, err := u.updateDirectory(ctx, dirFiles)
// Change the CommitUpdateRecord in the Store to reflect the results of the directory update.
if err != nil {
ur.Error = err.Error()
- if err2 := st.SetCommitUpdateRecord(ctx, ur); err2 != nil {
+ if err2 := u.st.SetCommitUpdateRecord(ctx, ur); err2 != nil {
return ur, fmt.Errorf("update failed with %w, could not set update record: %v", err, err2)
}
}
ur.NumProcessed += numProc
ur.NumAdded += numAdds
ur.NumModified += numMods
- if err := st.SetCommitUpdateRecord(ctx, ur); err != nil {
+ if err := u.st.SetCommitUpdateRecord(ctx, ur); err != nil {
return ur, err
}
}
ur.EndedAt = time.Now()
- return ur, st.SetCommitUpdateRecord(ctx, ur)
+ return ur, u.st.SetCommitUpdateRecord(ctx, ur)
}
-func updateDirectory(ctx context.Context, dirFiles []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, knownIDs map[string]bool, needsIssue triageFunc) (numProc, numAdds, numMods int, err error) {
+func (u *updater) updateDirectory(ctx context.Context, dirFiles []repoFile) (numProc, numAdds, numMods int, err error) {
dirPath := dirFiles[0].dirPath
dirHash := dirFiles[0].treeHash.String()
// A non-empty directory hash means that we have fully processed the directory
// with that hash. If the stored hash matches the current one, we can skip
// this directory.
- dbHash, err := st.GetDirectoryHash(ctx, dirPath)
+ dbHash, err := u.st.GetDirectoryHash(ctx, dirPath)
if err != nil {
return 0, 0, 0, err
}
@@ -129,7 +147,7 @@
return 0, 0, 0, nil
}
// Set the hash to something that can't match, until we fully process this directory.
- if err := st.SetDirectoryHash(ctx, dirPath, "in progress"); err != nil {
+ if err := u.st.SetDirectoryHash(ctx, dirPath, "in progress"); err != nil {
return 0, 0, 0, err
}
// It's okay if we crash now; the directory hashes are just an optimization.
@@ -146,7 +164,7 @@
if j > len(dirFiles) {
j = len(dirFiles)
}
- numBatchAdds, numBatchMods, err := updateBatch(ctx, dirFiles[i:j], st, repo, commitHash, knownIDs, needsIssue)
+ numBatchAdds, numBatchMods, err := u.updateBatch(ctx, dirFiles[i:j])
if err != nil {
return 0, 0, 0, err
}
@@ -158,18 +176,18 @@
} // end batch loop
// We're done with this directory, so we can remember its hash.
- if err := st.SetDirectoryHash(ctx, dirPath, dirHash); err != nil {
+ if err := u.st.SetDirectoryHash(ctx, dirPath, dirHash); err != nil {
return 0, 0, 0, err
}
return numProc, numAdds, numMods, nil
}
-func updateBatch(ctx context.Context, batch []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, knownIDs map[string]bool, needsIssue triageFunc) (numAdds, numMods int, err error) {
+func (u *updater) updateBatch(ctx context.Context, batch []repoFile) (numAdds, numMods int, err error) {
startID := idFromFilename(batch[0].filename)
endID := idFromFilename(batch[len(batch)-1].filename)
defer derrors.Wrap(&err, "updateBatch(%s-%s)", startID, endID)
- err = st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
+ err = u.st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
numAdds = 0
numMods = 0
@@ -192,7 +210,7 @@
// No change; do nothing.
continue
}
- added, err := handleCVE(repo, f, old, commitHash, knownIDs, needsIssue, tx)
+ added, err := u.handleCVE(f, old, tx)
if err != nil {
return err
}
@@ -218,11 +236,11 @@
// handleCVE determines how to change the store for a single CVE.
// The CVE will definitely be either added, if it's new, or modified, if it's
// already in the DB.
-func handleCVE(repo *git.Repository, f repoFile, old *store.CVERecord, commitHash plumbing.Hash, knownIDs map[string]bool, needsIssue triageFunc, tx store.Transaction) (added bool, err error) {
+func (u *updater) handleCVE(f repoFile, old *store.CVERecord, tx store.Transaction) (added bool, err error) {
defer derrors.Wrap(&err, "handleCVE(%s)", f.filename)
// Read CVE from repo.
- r, err := blobReader(repo, f.blobHash)
+ r, err := blobReader(u.repo, f.blobHash)
if err != nil {
return false, err
}
@@ -232,8 +250,8 @@
return false, err
}
needs := false
- if cve.State == cveschema.StatePublic && !knownIDs[cve.ID] {
- needs, err = needsIssue(cve)
+ if cve.State == cveschema.StatePublic && !u.knownIDs[cve.ID] {
+ needs, err = u.needsIssue(cve)
if err != nil {
return false, err
}
@@ -242,7 +260,7 @@
// If the CVE is not in the database, add it.
if old == nil {
cr := store.NewCVERecord(cve, pathname, f.blobHash.String())
- cr.CommitHash = commitHash.String()
+ cr.CommitHash = u.commitHash.String()
if needs {
cr.TriageState = store.TriageStateNeedsIssue
} else {
@@ -258,7 +276,7 @@
mod.Path = pathname
mod.BlobHash = f.blobHash.String()
mod.CVEState = cve.State
- mod.CommitHash = commitHash.String()
+ mod.CommitHash = u.commitHash.String()
switch old.TriageState {
case store.TriageStateNoActionNeeded:
if needs {
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index b320e82..5bf8e91 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -167,7 +167,7 @@
t.Run(test.name, func(t *testing.T) {
mstore := store.NewMemStore()
createCVERecords(t, mstore, test.cur)
- if _, err := doUpdate(ctx, repo, h, mstore, knownVulns, needsIssue); err != nil {
+ if _, err := newUpdater(repo, h, mstore, knownVulns, needsIssue).update(ctx); err != nil {
t.Fatal(err)
}
got := mstore.CVERecords()
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index ea72716..d2d2c59 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -42,9 +42,10 @@
if err != nil {
return err
}
- _, err = doUpdate(ctx, repo, ch, st, knownVulnIDs, func(cve *cveschema.CVE) (bool, error) {
+ u := newUpdater(repo, ch, st, knownVulnIDs, func(cve *cveschema.CVE) (bool, error) {
return TriageCVE(ctx, cve, pkgsiteURL)
})
+ _, err = u.update(ctx)
return err
}