internal/worker: optimize updates by comparing tree hashes

We can avoid processing entire directories of the CVE repo
if we remember the hash of each directory and compare it
to the new hash.

This optimization dramatically speeds up processing a new commit,
because most files and directories are unchanged.

Change-Id: Ie3e78fefc720af0e8cd1864f3a352b1737a99f45
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/368597
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/go.mod b/go.mod
index 70f6e5d..fddd1f5 100644
--- a/go.mod
+++ b/go.mod
@@ -44,7 +44,7 @@
 	google.golang.org/api v0.60.0
 	google.golang.org/appengine v1.6.7 // indirect
 	google.golang.org/genproto v0.0.0-20211028162531-8db9c33dc351 // indirect
-	google.golang.org/grpc v1.40.0 // indirect
+	google.golang.org/grpc v1.40.0
 	google.golang.org/protobuf v1.27.1 // indirect
 	gopkg.in/warnings.v0 v0.1.2 // indirect
 	gopkg.in/yaml.v2 v2.4.0
diff --git a/internal/worker/store/fire_store.go b/internal/worker/store/fire_store.go
index 50eb5dc..3c7b037 100644
--- a/internal/worker/store/fire_store.go
+++ b/internal/worker/store/fire_store.go
@@ -7,10 +7,14 @@
 import (
 	"context"
 	"errors"
+	"fmt"
+	"strings"
 
 	"cloud.google.com/go/firestore"
 	"golang.org/x/vuln/internal/derrors"
 	"google.golang.org/api/iterator"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 // FireStore is a Store implemented with Google Cloud Firestore.
@@ -25,6 +29,7 @@
 // are some collections:
 // - CVEs for CVERecords
 // - CommitUpdates for CommitUpdateRecords
+// - DirHashes for directory hashes
 type FireStore struct {
 	namespace string
 	client    *firestore.Client
@@ -35,6 +40,7 @@
 	namespaceCollection = "Namespaces"
 	updateCollection    = "Updates"
 	cveCollection       = "CVEs"
+	dirHashCollection   = "DirHashes"
 )
 
 // NewFireStore creates a new FireStore, backed by a client to Firestore. Since
@@ -108,6 +114,46 @@
 	return urs, nil
 }
 
+type dirHash struct {
+	Hash string
+}
+
+// dirHashRef returns a DocumentRef for the directory dir.
+func (s *FireStore) dirHashRef(dir string) *firestore.DocumentRef {
+	// Firestore IDs cannot contain slashes.
+	// Do something simple and readable to fix that.
+	id := strings.ReplaceAll(dir, "/", "|")
+	return s.nsDoc.Collection(dirHashCollection).Doc(id)
+}
+
+// GetDirectoryHash implements Transaction.GetDirectoryHash.
+func (fs *FireStore) GetDirectoryHash(ctx context.Context, dir string) (_ string, err error) {
+	defer derrors.Wrap(&err, "GetDirectoryHash(%s)", dir)
+
+	ds, err := fs.dirHashRef(dir).Get(ctx)
+	if err != nil {
+		if status.Code(err) == codes.NotFound {
+			return "", nil
+		}
+		return "", err
+	}
+	data, err := ds.DataAt("Hash")
+	if err != nil {
+		return "", err
+	}
+	hash, ok := data.(string)
+	if !ok {
+		return "", fmt.Errorf("hash data for %s is not a string", dir)
+	}
+	return hash, nil
+}
+
+// SetDirectoryHash implements Transaction.SetDirectoryHash.
+func (fs *FireStore) SetDirectoryHash(ctx context.Context, dir, hash string) error {
+	_, err := fs.dirHashRef(dir).Set(ctx, dirHash{Hash: hash})
+	return err
+}
+
 // RunTransaction implements Store.RunTransaction.
 func (fs *FireStore) RunTransaction(ctx context.Context, f func(context.Context, Transaction) error) error {
 	return fs.client.RunTransaction(ctx,
diff --git a/internal/worker/store/mem_store.go b/internal/worker/store/mem_store.go
index 2e8b1ac..06b72e7 100644
--- a/internal/worker/store/mem_store.go
+++ b/internal/worker/store/mem_store.go
@@ -19,6 +19,7 @@
 	mu            sync.Mutex
 	cveRecords    map[string]*CVERecord
 	updateRecords map[string]*CommitUpdateRecord
+	dirHashes     map[string]string
 }
 
 // NewMemStore creates a new, empty MemStore.
@@ -32,6 +33,7 @@
 func (ms *MemStore) Clear(context.Context) error {
 	ms.cveRecords = map[string]*CVERecord{}
 	ms.updateRecords = map[string]*CommitUpdateRecord{}
+	ms.dirHashes = map[string]string{}
 	return nil
 }
 
@@ -76,6 +78,17 @@
 	return urs, nil
 }
 
+// GetDirectoryHash implements Transaction.GetDirectoryHash.
+func (ms *MemStore) GetDirectoryHash(_ context.Context, dir string) (string, error) {
+	return ms.dirHashes[dir], nil
+}
+
+// SetDirectoryHash implements Transaction.SetDirectoryHash.
+func (ms *MemStore) SetDirectoryHash(_ context.Context, dir, hash string) error {
+	ms.dirHashes[dir] = hash
+	return nil
+}
+
 // RunTransaction implements Store.RunTransaction.
 // A transaction runs with a single lock on the entire DB.
 func (ms *MemStore) RunTransaction(ctx context.Context, f func(context.Context, Transaction) error) error {
diff --git a/internal/worker/store/store.go b/internal/worker/store/store.go
index ab6f80a..c049474 100644
--- a/internal/worker/store/store.go
+++ b/internal/worker/store/store.go
@@ -135,6 +135,13 @@
 	// least recent.
 	ListCommitUpdateRecords(ctx context.Context, limit int) ([]*CommitUpdateRecord, error)
 
+	// GetDirectoryHash returns the hash for the tree object corresponding to dir.
+	// If dir isn't found, it succeeds with the empty string.
+	GetDirectoryHash(ctx context.Context, dir string) (string, error)
+
+	// SetDirectoryHash sets the hash for the given directory.
+	SetDirectoryHash(ctx context.Context, dir, hash string) error
+
 	// RunTransaction runs the function in a transaction.
 	RunTransaction(context.Context, func(context.Context, Transaction) error) error
 }
diff --git a/internal/worker/store/store_test.go b/internal/worker/store/store_test.go
index 777c475..e68b805 100644
--- a/internal/worker/store/store_test.go
+++ b/internal/worker/store/store_test.go
@@ -28,6 +28,9 @@
 	t.Run("CVEs", func(t *testing.T) {
 		testCVEs(t, s)
 	})
+	t.Run("DirHashes", func(t *testing.T) {
+		testDirHashes(t, s)
+	})
 }
 
 func testUpdates(t *testing.T, s Store) {
@@ -148,6 +151,29 @@
 	diff(t, &want, got)
 }
 
+func testDirHashes(t *testing.T, s Store) {
+	ctx := context.Background()
+	const dir = "a/b/c"
+	got, err := s.GetDirectoryHash(ctx, dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != "" {
+		t.Fatalf("got %q, want empty", got)
+	}
+	const want = "123"
+	if err := s.SetDirectoryHash(ctx, "a/b/c", want); err != nil {
+		t.Fatal(err)
+	}
+	got, err = s.GetDirectoryHash(ctx, dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got != want {
+		t.Fatalf("got %q, want %q", got, want)
+	}
+}
+
 func createCVERecords(t *testing.T, ctx context.Context, s Store, crs []*CVERecord) {
 	err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 		for _, cr := range crs {
diff --git a/internal/worker/update.go b/internal/worker/update.go
index 7fe48ea..e2d87db 100644
--- a/internal/worker/update.go
+++ b/internal/worker/update.go
@@ -11,6 +11,7 @@
 	"io"
 	"path"
 	"sort"
+	"strconv"
 	"strings"
 	"time"
 
@@ -32,7 +33,7 @@
 // of the DB and updates the DB to match.
 //
 // needsIssue determines whether a CVE needs an issue to be filed for it.
-func doUpdate(ctx context.Context, repo *git.Repository, commitHash plumbing.Hash, st store.Store, needsIssue triageFunc) (err error) {
+func doUpdate(ctx context.Context, repo *git.Repository, commitHash plumbing.Hash, st store.Store, needsIssue triageFunc) (ur *store.CommitUpdateRecord, err error) {
 	// We want the action of reading the old DB record, updating it and
 	// writing it back to be atomic. It would be too expensive to do that one
 	// record at a time. Ideally we'd process the whole repo commit in one
@@ -45,7 +46,11 @@
 		if err != nil {
 			log.Error(ctx, "update failed", event.Value("error", err))
 		} else {
-			log.Info(ctx, "update succeeded")
+			nProcessed := int64(0)
+			if ur != nil {
+				nProcessed = int64(ur.NumProcessed)
+			}
+			log.Info(ctx, "update succeeded", event.Int64("numProcessed", nProcessed))
 		}
 	}()
 
@@ -53,7 +58,7 @@
 
 	commit, err := repo.CommitObject(commitHash)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	// Get all the CVE files.
@@ -62,52 +67,95 @@
 	// each file individually.
 	files, err := repoCVEFiles(repo, commit)
 	if err != nil {
-		return err
+		return nil, err
 	}
+	// Process files in the same directory together, so we can easily skip
+	// the entire directory if it hasn't changed.
+	filesByDir, err := groupFilesByDirectory(files)
+	if err != nil {
+		return nil, err
+	}
+
 	// Create a new CommitUpdateRecord to describe this run of doUpdate.
-	ur := &store.CommitUpdateRecord{
+	ur = &store.CommitUpdateRecord{
 		StartedAt:  time.Now(),
 		CommitHash: commitHash.String(),
 		CommitTime: commit.Committer.When,
 		NumTotal:   len(files),
 	}
 	if err := st.CreateCommitUpdateRecord(ctx, ur); err != nil {
-		return err
+		return ur, err
 	}
 
-	// Update files in batches.
-
-	// Max Firestore writes per transaction.
-	// See https://cloud.google.com/firestore/quotas.
-	const batchSize = 500
-
-	for i := 0; i < len(files); i += batchSize {
-		j := i + batchSize
-		if j > len(files) {
-			j = len(files)
-		}
-		numAdds, numMods, err := updateBatch(ctx, files[i:j], st, repo, commitHash, needsIssue)
-
-		// Change the CommitUpdateRecord in the Store to reflect the results of the transaction.
+	for _, dirFiles := range filesByDir {
+		numProc, numAdds, numMods, err := updateDirectory(ctx, dirFiles, st, repo, commitHash, needsIssue)
+		// Change the CommitUpdateRecord in the Store to reflect the results of the directory update.
 		if err != nil {
 			ur.Error = err.Error()
 			if err2 := st.SetCommitUpdateRecord(ctx, ur); err2 != nil {
-				return fmt.Errorf("update failed with %w, could not set update record: %v", err, err2)
+				return ur, fmt.Errorf("update failed with %w, could not set update record: %v", err, err2)
 			}
-			return err
 		}
-		ur.NumProcessed += j - i
-		// Add in these two numbers here, instead of in the function passed to
-		// RunTransaction, because that function may be executed multiple times.
+		ur.NumProcessed += numProc
 		ur.NumAdded += numAdds
 		ur.NumModified += numMods
 		if err := st.SetCommitUpdateRecord(ctx, ur); err != nil {
-			return err
+			return ur, err
 		}
-	} // end loop
-
+	}
 	ur.EndedAt = time.Now()
-	return st.SetCommitUpdateRecord(ctx, ur)
+	return ur, st.SetCommitUpdateRecord(ctx, ur)
+}
+
+func updateDirectory(ctx context.Context, dirFiles []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, needsIssue triageFunc) (numProc, numAdds, numMods int, err error) {
+	dirPath := dirFiles[0].dirPath
+	dirHash := dirFiles[0].treeHash.String()
+
+	// A non-empty directory hash means that we have fully processed the directory
+	// with that hash. If the stored hash matches the current one, we can skip
+	// this directory.
+	dbHash, err := st.GetDirectoryHash(ctx, dirPath)
+	if err != nil {
+		return 0, 0, 0, err
+	}
+	if dirHash == dbHash {
+		log.Infof(ctx, "skipping directory %s because the hashes match", dirPath)
+		return 0, 0, 0, nil
+	}
+	// Set the hash to something that can't match, until we fully process this directory.
+	if err := st.SetDirectoryHash(ctx, dirPath, "in progress"); err != nil {
+		return 0, 0, 0, err
+	}
+	// It's okay if we crash now; the directory hashes are just an optimization.
+	// At worst we'll redo this directory next time.
+
+	// Update files in batches.
+
+	// Firestore supports a maximum of 500 writes per transaction.
+	// See https://cloud.google.com/firestore/quotas.
+	const batchSize = 500
+
+	for i := 0; i < len(dirFiles); i += batchSize {
+		j := i + batchSize
+		if j > len(dirFiles) {
+			j = len(dirFiles)
+		}
+		numBatchAdds, numBatchMods, err := updateBatch(ctx, dirFiles[i:j], st, repo, commitHash, needsIssue)
+		if err != nil {
+			return 0, 0, 0, err
+		}
+		numProc += j - i
+		// Add in these two numbers here, instead of in the function passed to
+		// RunTransaction, because that function may be executed multiple times.
+		numAdds += numBatchAdds
+		numMods += numBatchMods
+	} // end batch loop
+
+	// We're done with this directory, so we can remember its hash.
+	if err := st.SetDirectoryHash(ctx, dirPath, dirHash); err != nil {
+		return 0, 0, 0, err
+	}
+	return numProc, numAdds, numMods, nil
 }
 
 func updateBatch(ctx context.Context, batch []repoFile, st store.Store, repo *git.Repository, commitHash plumbing.Hash, needsIssue triageFunc) (numAdds, numMods int, err error) {
@@ -118,6 +166,7 @@
 	err = st.RunTransaction(ctx, func(ctx context.Context, tx store.Transaction) error {
 		numAdds = 0
 		numMods = 0
+
 		// Read information about the existing state in the store that's
 		// relevant to this batch. Since the entries are sorted, we can read
 		// a range of IDS.
@@ -133,7 +182,7 @@
 		for _, f := range batch {
 			id := idFromFilename(f.filename)
 			old := idToRecord[id]
-			if old != nil && old.BlobHash == f.hash.String() {
+			if old != nil && old.BlobHash == f.blobHash.String() {
 				// No change; do nothing.
 				continue
 			}
@@ -167,11 +216,11 @@
 	defer derrors.Wrap(&err, "handleCVE(%s)", f.filename)
 
 	// Read CVE from repo.
-	r, err := blobReader(repo, f.hash)
+	r, err := blobReader(repo, f.blobHash)
 	if err != nil {
 		return false, err
 	}
-	pathname := path.Join(f.dirpath, f.filename)
+	pathname := path.Join(f.dirPath, f.filename)
 	cve := &cveschema.CVE{}
 	if err := json.NewDecoder(r).Decode(cve); err != nil {
 		return false, err
@@ -186,7 +235,7 @@
 
 	// If the CVE is not in the database, add it.
 	if old == nil {
-		cr := store.NewCVERecord(cve, path.Join(f.dirpath, f.filename), f.hash.String())
+		cr := store.NewCVERecord(cve, pathname, f.blobHash.String())
 		cr.CommitHash = commitHash.String()
 		if needs {
 			cr.TriageState = store.TriageStateNeedsIssue
@@ -201,7 +250,7 @@
 	// Change to an existing record.
 	mod := *old // copy the old one
 	mod.Path = pathname
-	mod.BlobHash = f.hash.String()
+	mod.BlobHash = f.blobHash.String()
 	mod.CVEState = cve.State
 	mod.CommitHash = commitHash.String()
 	switch old.TriageState {
@@ -235,9 +284,12 @@
 }
 
 type repoFile struct {
-	dirpath  string
+	dirPath  string
 	filename string
-	hash     plumbing.Hash
+	treeHash plumbing.Hash
+	blobHash plumbing.Hash
+	year     int
+	number   int
 }
 
 // repoCVEFiles returns all the CVE files in the given repo commit, sorted by
@@ -254,11 +306,19 @@
 		return nil, err
 	}
 	sort.Slice(files, func(i, j int) bool {
-		return files[i].filename < files[j].filename
+		// Compare the year and the number, as ints. Using the ID directly
+		// would put CVE-2014-100009 before CVE-2014-10001.
+		if files[i].year != files[j].year {
+			return files[i].year < files[j].year
+		}
+		return files[i].number < files[j].number
 	})
 	return files, nil
 }
 
+// func cmpIDs(id1, i2 string) int  {
+// 	toInts := func(id string) [2]int {
+
 // walkFiles collects CVE files from a repo tree.
 func walkFiles(repo *git.Repository, tree *object.Tree, dirpath string, files []repoFile) ([]repoFile, error) {
 	for _, e := range tree.Entries {
@@ -272,16 +332,58 @@
 				return nil, err
 			}
 		} else if isCVEFilename(e.Name) {
+			// e.Name is CVE-YEAR-NUMBER.json
+			year, err := strconv.Atoi(e.Name[4:8])
+			if err != nil {
+				return nil, err
+			}
+			number, err := strconv.Atoi(e.Name[9 : len(e.Name)-5])
+			if err != nil {
+				return nil, err
+			}
 			files = append(files, repoFile{
-				dirpath:  dirpath,
+				dirPath:  dirpath,
 				filename: e.Name,
-				hash:     e.Hash,
+				treeHash: tree.Hash,
+				blobHash: e.Hash,
+				year:     year,
+				number:   number,
 			})
 		}
 	}
 	return files, nil
 }
 
+// groupFilesByDirectory collects files by directory, verifying that each
+// directory is contiguous in the list. Our directory hash optimization depends on that.
+func groupFilesByDirectory(files []repoFile) ([][]repoFile, error) {
+	if len(files) == 0 {
+		return nil, nil
+	}
+	var (
+		result [][]repoFile
+		curDir []repoFile
+	)
+	for _, f := range files {
+		if len(curDir) > 0 && f.dirPath != curDir[0].dirPath {
+			result = append(result, curDir)
+			curDir = nil
+		}
+		curDir = append(curDir, f)
+	}
+	if len(curDir) > 0 {
+		result = append(result, curDir)
+	}
+	seen := map[string]bool{}
+	for _, dir := range result {
+		if seen[dir[0].dirPath] {
+			return nil, fmt.Errorf("directory %s is not contiguous in the sorted list of files", dir[0].dirPath)
+		}
+		seen[dir[0].dirPath] = true
+	}
+	return result, nil
+}
+
 // blobReader returns a reader to the blob with the given hash.
 func blobReader(repo *git.Repository, hash plumbing.Hash) (io.Reader, error) {
 	blob, err := repo.BlobObject(hash)
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index a0d844e..5e7f517 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -35,12 +35,13 @@
 	}
 
 	want := []repoFile{
-		{dirpath: "2021/0xxx", filename: "CVE-2021-0001.json"},
-		{dirpath: "2021/0xxx", filename: "CVE-2021-0010.json"},
-		{dirpath: "2021/1xxx", filename: "CVE-2021-1384.json"},
+		{dirPath: "2021/0xxx", filename: "CVE-2021-0001.json", year: 2021, number: 1},
+		{dirPath: "2021/0xxx", filename: "CVE-2021-0010.json", year: 2021, number: 10},
+		{dirPath: "2021/1xxx", filename: "CVE-2021-1384.json", year: 2021, number: 1384},
 	}
 
-	if diff := cmp.Diff(want, got, cmp.AllowUnexported(repoFile{}), cmpopts.IgnoreFields(repoFile{}, "hash")); diff != "" {
+	opt := cmpopts.IgnoreFields(repoFile{}, "treeHash", "blobHash")
+	if diff := cmp.Diff(want, got, cmp.AllowUnexported(repoFile{}), opt); diff != "" {
 		t.Errorf("mismatch (-want, +got):\n%s", diff)
 	}
 }
@@ -165,7 +166,7 @@
 		t.Run(test.name, func(t *testing.T) {
 			mstore := store.NewMemStore()
 			createCVERecords(t, mstore, test.cur)
-			if err := doUpdate(ctx, repo, h, mstore, needsIssue); err != nil {
+			if _, err := doUpdate(ctx, repo, h, mstore, needsIssue); err != nil {
 				t.Fatal(err)
 			}
 			got := mstore.CVERecords()
@@ -180,6 +181,64 @@
 	}
 }
 
+func TestGroupFilesByDirectory(t *testing.T) {
+	for _, test := range []struct {
+		in   []repoFile
+		want [][]repoFile
+	}{
+		{in: nil, want: nil},
+		{
+			in:   []repoFile{{dirPath: "a"}},
+			want: [][]repoFile{{{dirPath: "a"}}},
+		},
+		{
+			in: []repoFile{
+				{dirPath: "a", filename: "f1"},
+				{dirPath: "a", filename: "f2"},
+			},
+			want: [][]repoFile{{
+				{dirPath: "a", filename: "f1"},
+				{dirPath: "a", filename: "f2"},
+			}},
+		},
+		{
+			in: []repoFile{
+				{dirPath: "a", filename: "f1"},
+				{dirPath: "a", filename: "f2"},
+				{dirPath: "b", filename: "f1"},
+				{dirPath: "c", filename: "f1"},
+				{dirPath: "c", filename: "f2"},
+			},
+			want: [][]repoFile{
+				{
+					{dirPath: "a", filename: "f1"},
+					{dirPath: "a", filename: "f2"},
+				},
+				{
+					{dirPath: "b", filename: "f1"},
+				},
+				{
+					{dirPath: "c", filename: "f1"},
+					{dirPath: "c", filename: "f2"},
+				},
+			},
+		},
+	} {
+		got, err := groupFilesByDirectory(test.in)
+		if err != nil {
+			t.Fatalf("%v: %v", test.in, err)
+		}
+		if diff := cmp.Diff(got, test.want, cmp.AllowUnexported(repoFile{})); diff != "" {
+			t.Errorf("%v: (-want, +got)\n%s", test.in, diff)
+		}
+	}
+
+	_, err := groupFilesByDirectory([]repoFile{{dirPath: "a"}, {dirPath: "b"}, {dirPath: "a"}})
+	if err == nil {
+		t.Error("got nil, want error")
+	}
+}
+
 func readCVE(t *testing.T, repo *git.Repository, path string) (*cveschema.CVE, string) {
 	c := headCommit(t, repo)
 	file, err := c.File(path)
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index 0c53eb9..6042787 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -35,9 +35,10 @@
 			return err
 		}
 	}
-	return doUpdate(ctx, repo, ch, st, func(cve *cveschema.CVE) (bool, error) {
+	_, err = doUpdate(ctx, repo, ch, st, func(cve *cveschema.CVE) (bool, error) {
 		return TriageCVE(ctx, cve, pkgsiteURL)
 	})
+	return err
 }
 
 // checkUpdate performs sanity checks on a potential update.