internal/worker: refactor doUpdate

Create a struct for the args instead of passing them down
to multiple functions.

Change-Id: Ifb5798cc2f0f963ac04a48313e8342499b98f51d
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/368858
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/worker/update.go b/internal/worker/update.go
index 07e8640..5b269f2 100644
--- a/internal/worker/update.go
+++ b/internal/worker/update.go
@@ -29,18 +29,42 @@
 // A triageFunc triages a CVE: it decides whether an issue needs to be filed.
 type triageFunc func(*cveschema.CVE) (bool, error)
 
-// doUpdate compares the repo at the given commit with the state
-// of the DB and updates the DB to match.
-//
+// An updater performs an update operation on the DB.
+type updater struct {
+	repo       *git.Repository
+	commitHash plumbing.Hash
+	st         store.Store
+	knownIDs   map[string]bool
+	needsIssue triageFunc
+}
+
+// newUpdater creates an updater for updating the store with information from
+// the repo commit.
 // needsIssue determines whether a CVE needs an issue to be filed for it.
-func doUpdate(ctx context.Context, repo *git.Repository, commitHash plumbing.Hash, st store.Store, knownVulnIDs []string, needsIssue triageFunc) (ur *store.CommitUpdateRecord, err error) {
+func newUpdater(repo *git.Repository, commitHash plumbing.Hash, st store.Store, knownVulnIDs []string, needsIssue triageFunc) *updater {
+	u := &updater{
+		repo:       repo,
+		commitHash: commitHash,
+		st:         st,
+		knownIDs:   map[string]bool{},
+		needsIssue: needsIssue,
+	}
+	for _, k := range knownVulnIDs {
+		u.knownIDs[k] = true
+	}
+	return u
+}
+
+// update updates the DB to match the repo at the given commit.
+// It also triages new or changed CVEs.
+func (u *updater) update(ctx context.Context) (ur *store.CommitUpdateRecord, err error) {
 	// We want the action of reading the old DB record, updating it and
 	// writing it back to be atomic. It would be too expensive to do that one
 	// record at a time. Ideally we'd process the whole repo commit in one
 	// transaction, but Firestore has a limit on how many writes one
 	// transaction can do, so the CVE files in the repo are processed in
 	// batches, one transaction per batch.
-	defer derrors.Wrap(&err, "doUpdate(%s)", commitHash)
+	defer derrors.Wrap(&err, "updater.update(%s)", u.commitHash)
 
 	defer func() {
 		if err != nil {
@@ -54,9 +78,9 @@
 		}
 	}()
 
-	log.Info(ctx, "update starting", event.String("commit", commitHash.String()))
+	log.Info(ctx, "update starting", event.String("commit", u.commitHash.String()))
 
-	commit, err := repo.CommitObject(commitHash)
+	commit, err := u.repo.CommitObject(u.commitHash)
 	if err != nil {
 		return nil, err
 	}
@@ -65,7 +89,7 @@
 	// It is cheaper to read all the files from the repo and compare
 	// them to the DB in bulk, than to walk the repo and process
 	// each file individually.
-	files, err := repoCVEFiles(repo, commit)
+	files, err := repoCVEFiles(u.repo, commit)
 	if err != nil {
 		return nil, err
 	}
@@ -76,51 +100,45 @@
 		return nil, err
 	}
 
-	// Put the known vuln CVE IDs into a map.
-	known := map[string]bool{}
-	for _, k := range knownVulnIDs {
-		known[k] = true
-	}
-
-	// Create a new CommitUpdateRecord to describe this run of doUpdate.
+	// Create a new CommitUpdateRecord to describe this run of update.
 	ur = &store.CommitUpdateRecord{
 		StartedAt:  time.Now(),
-		CommitHash: commitHash.String(),
+		CommitHash: u.commitHash.String(),
 		CommitTime: commit.Committer.When,
 		NumTotal:   len(files),
 	}
-	if err := st.CreateCommitUpdateRecord(ctx, ur); err != nil {
+	if err := u.st.CreateCommitUpdateRecord(ctx, ur); err != nil {
 		return ur, err
 	}
 
 	for _, dirFiles := range filesByDir {
-		numProc, numAdds, numMods, err := updateDirectory(ctx, dirFiles, st, repo, commitHash, known, needsIssue)
+		numProc, numAdds, numMods, err := u.updateDirectory(ctx, dirFiles)
 		// Change the CommitUpdateRecord in the Store to reflect the results of the directory update.
 		if err != nil {
 			ur.Error = err.Error()
-			if err2 := st.SetCommitUpdateRecord(ctx, ur); err2 != nil {
+			if err2 := u.st.SetCommitUpdateRecord(ctx, ur); err2 != nil {
 				return ur, fmt.Errorf("update failed with %w, could not set update record: %v", err, err2)
 			}
 		}
 		ur.NumProcessed += numProc
 		ur.NumAdded += numAdds
 		ur.NumModified += numMods
-		if err := st.SetCommitUpdateRecord(ctx, ur); err != nil {
+		if err := u.st.SetCommitUpdateRecord(ctx, ur); err != nil {
 			return ur, err
 		}
 	}
 	ur.EndedAt = time.Now()
-	return ur, st.SetCommitUpdateRecord(ctx, ur)
+	return ur, u.st.SetCommitUpdateRecord(ctx, ur)
 }
 
-func updateDirectory(ctx context.Context, dirFiles []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, knownIDs map[string]bool, needsIssue triageFunc) (numProc, numAdds, numMods int, err error) {
+func (u *updater) updateDirectory(ctx context.Context, dirFiles []repoFile) (numProc, numAdds, numMods int, err error) {
 	dirPath := dirFiles[0].dirPath
 	dirHash := dirFiles[0].treeHash.String()
 
 	// A non-empty directory hash means that we have fully processed the directory
 	// with that hash. If the stored hash matches the current one, we can skip
 	// this directory.
-	dbHash, err := st.GetDirectoryHash(ctx, dirPath)
+	dbHash, err := u.st.GetDirectoryHash(ctx, dirPath)
 	if err != nil {
 		return 0, 0, 0, err
 	}
@@ -129,7 +147,7 @@
 		return 0, 0, 0, nil
 	}
 	// Set the hash to something that can't match, until we fully process this directory.
-	if err := st.SetDirectoryHash(ctx, dirPath, "in progress"); err != nil {
+	if err := u.st.SetDirectoryHash(ctx, dirPath, "in progress"); err != nil {
 		return 0, 0, 0, err
 	}
 	// It's okay if we crash now; the directory hashes are just an optimization.
@@ -146,7 +164,7 @@
 		if j > len(dirFiles) {
 			j = len(dirFiles)
 		}
-		numBatchAdds, numBatchMods, err := updateBatch(ctx, dirFiles[i:j], st, repo, commitHash, knownIDs, needsIssue)
+		numBatchAdds, numBatchMods, err := u.updateBatch(ctx, dirFiles[i:j])
 		if err != nil {
 			return 0, 0, 0, err
 		}
@@ -158,18 +176,18 @@
 	} // end batch loop
 
 	// We're done with this directory, so we can remember its hash.
-	if err := st.SetDirectoryHash(ctx, dirPath, dirHash); err != nil {
+	if err := u.st.SetDirectoryHash(ctx, dirPath, dirHash); err != nil {
 		return 0, 0, 0, err
 	}
 	return numProc, numAdds, numMods, nil
 }
 
-func updateBatch(ctx context.Context, batch []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, knownIDs map[string]bool, needsIssue triageFunc) (numAdds, numMods int, err error) {
+func (u *updater) updateBatch(ctx context.Context, batch []repoFile) (numAdds, numMods int, err error) {
 	startID := idFromFilename(batch[0].filename)
 	endID := idFromFilename(batch[len(batch)-1].filename)
 	defer derrors.Wrap(&err, "updateBatch(%s-%s)", startID, endID)
 
-	err = st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
+	err = u.st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
 		numAdds = 0
 		numMods = 0
 
@@ -192,7 +210,7 @@
 				// No change; do nothing.
 				continue
 			}
-			added, err := handleCVE(repo, f, old, commitHash, knownIDs, needsIssue, tx)
+			added, err := u.handleCVE(f, old, tx)
 			if err != nil {
 				return err
 			}
@@ -218,11 +236,11 @@
 // handleCVE determines how to change the store for a single CVE.
 // The CVE will definitely be either added, if it's new, or modified, if it's
 // already in the DB.
-func handleCVE(repo *git.Repository, f repoFile, old *store.CVERecord, commitHash plumbing.Hash, knownIDs map[string]bool, needsIssue triageFunc, tx store.Transaction) (added bool, err error) {
+func (u *updater) handleCVE(f repoFile, old *store.CVERecord, tx store.Transaction) (added bool, err error) {
 	defer derrors.Wrap(&err, "handleCVE(%s)", f.filename)
 
 	// Read CVE from repo.
-	r, err := blobReader(repo, f.blobHash)
+	r, err := blobReader(u.repo, f.blobHash)
 	if err != nil {
 		return false, err
 	}
@@ -232,8 +250,8 @@
 		return false, err
 	}
 	needs := false
-	if cve.State == cveschema.StatePublic && !knownIDs[cve.ID] {
-		needs, err = needsIssue(cve)
+	if cve.State == cveschema.StatePublic && !u.knownIDs[cve.ID] {
+		needs, err = u.needsIssue(cve)
 		if err != nil {
 			return false, err
 		}
@@ -242,7 +260,7 @@
 	// If the CVE is not in the database, add it.
 	if old == nil {
 		cr := store.NewCVERecord(cve, pathname, f.blobHash.String())
-		cr.CommitHash = commitHash.String()
+		cr.CommitHash = u.commitHash.String()
 		if needs {
 			cr.TriageState = store.TriageStateNeedsIssue
 		} else {
@@ -258,7 +276,7 @@
 	mod.Path = pathname
 	mod.BlobHash = f.blobHash.String()
 	mod.CVEState = cve.State
-	mod.CommitHash = commitHash.String()
+	mod.CommitHash = u.commitHash.String()
 	switch old.TriageState {
 	case store.TriageStateNoActionNeeded:
 		if needs {
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index b320e82..5bf8e91 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -167,7 +167,7 @@
 		t.Run(test.name, func(t *testing.T) {
 			mstore := store.NewMemStore()
 			createCVERecords(t, mstore, test.cur)
-			if _, err := doUpdate(ctx, repo, h, mstore, knownVulns, needsIssue); err != nil {
+			if _, err := newUpdater(repo, h, mstore, knownVulns, needsIssue).update(ctx); err != nil {
 				t.Fatal(err)
 			}
 			got := mstore.CVERecords()
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index ea72716..d2d2c59 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -42,9 +42,10 @@
 	if err != nil {
 		return err
 	}
-	_, err = doUpdate(ctx, repo, ch, st, knownVulnIDs, func(cve *cveschema.CVE) (bool, error) {
+	u := newUpdater(repo, ch, st, knownVulnIDs, func(cve *cveschema.CVE) (bool, error) {
 		return TriageCVE(ctx, cve, pkgsiteURL)
 	})
+	_, err = u.update(ctx)
 	return err
 }