internal/worker/store: minor changes
- Add commit time to CommitUpdateRecord. We'll use it to perform
a sanity check to make sure updates don't happen with older commits.
- Add a limit argument to ListCommitUpdateRecords.
- Also, use some functionality in the git package that I hadn't noticed
before, removing some of our code.
Change-Id: Icba4a11d6700ac71ab45923304be802353677c3d
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/368595
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/cmd/worker/main.go b/cmd/worker/main.go
index 0019fd7..530cc03 100644
--- a/cmd/worker/main.go
+++ b/cmd/worker/main.go
@@ -94,7 +94,7 @@
}
func listUpdatesCommand(ctx context.Context, st store.Store) error {
- recs, err := st.ListCommitUpdateRecords(ctx)
+ recs, err := st.ListCommitUpdateRecords(ctx, 0)
if err != nil {
return err
}
diff --git a/internal/worker/repo_test.go b/internal/worker/repo_test.go
index 04a958c..b1ae019 100644
--- a/internal/worker/repo_test.go
+++ b/internal/worker/repo_test.go
@@ -5,8 +5,6 @@
package worker
import (
- "io/ioutil"
- "path"
"testing"
"time"
@@ -87,53 +85,3 @@
}
return ref.Hash(), nil
}
-
-// findBlob returns the blob at filename in repo.
-// It fail the test if it doesn't exist.
-func findBlob(t *testing.T, repo *git.Repository, filename string) *object.Blob {
- c := headCommit(t, repo)
- tree, err := repo.TreeObject(c.TreeHash)
- if err != nil {
- t.Fatal(err)
- }
- e := findEntry(t, repo, tree, filename)
- blob, err := repo.BlobObject(e.Hash)
- if err != nil {
- t.Fatal(err)
- }
- return blob
-}
-
-// readBlob reads the contents of a blob.
-func readBlob(t *testing.T, blob *object.Blob) []byte {
- r, err := blob.Reader()
- if err != nil {
- t.Fatal(err)
- }
- data, err := ioutil.ReadAll(r)
- if err != nil {
- t.Fatal(err)
- }
- return data
-}
-
-// findEntry returns the TreeEntry at filename. It fails the test if
-// it doesn't exist.
-func findEntry(t *testing.T, repo *git.Repository, tree *object.Tree, filename string) object.TreeEntry {
- dir, base := path.Split(filename)
- if dir != "" {
- te := findEntry(t, repo, tree, dir[:len(dir)-1])
- var err error
- tree, err = repo.TreeObject(te.Hash)
- if err != nil {
- t.Fatal(err)
- }
- }
- for _, e := range tree.Entries {
- if e.Name == base {
- return e
- }
- }
- t.Fatalf("could not find %q in repo", filename)
- return object.TreeEntry{}
-}
diff --git a/internal/worker/store/fire_store.go b/internal/worker/store/fire_store.go
index 98940de..90078df 100644
--- a/internal/worker/store/fire_store.go
+++ b/internal/worker/store/fire_store.go
@@ -84,9 +84,13 @@
}
// ListCommitUpdateRecords implements Store.ListCommitUpdateRecords.
-func (fs *FireStore) ListCommitUpdateRecords(ctx context.Context) ([]*CommitUpdateRecord, error) {
+func (fs *FireStore) ListCommitUpdateRecords(ctx context.Context, limit int) ([]*CommitUpdateRecord, error) {
var urs []*CommitUpdateRecord
- iter := fs.nsDoc.Collection(updateCollection).Documents(ctx)
+ q := fs.nsDoc.Collection(updateCollection).Query
+ if limit > 0 {
+ q = q.Limit(limit)
+ }
+ iter := q.Documents(ctx)
for {
docsnap, err := iter.Next()
if err == iterator.Done {
diff --git a/internal/worker/store/mem_store.go b/internal/worker/store/mem_store.go
index 62282f1..2e8b1ac 100644
--- a/internal/worker/store/mem_store.go
+++ b/internal/worker/store/mem_store.go
@@ -62,7 +62,7 @@
}
// ListCommitUpdateRecords implements Store.ListCommitUpdateRecords.
-func (ms *MemStore) ListCommitUpdateRecords(context.Context) ([]*CommitUpdateRecord, error) {
+func (ms *MemStore) ListCommitUpdateRecords(_ context.Context, limit int) ([]*CommitUpdateRecord, error) {
var urs []*CommitUpdateRecord
for _, ur := range ms.updateRecords {
urs = append(urs, ur)
@@ -70,6 +70,9 @@
sort.Slice(urs, func(i, j int) bool {
return urs[i].StartedAt.After(urs[j].StartedAt)
})
+ if limit > 0 && len(urs) > limit {
+ urs = urs[:limit]
+ }
return urs, nil
}
diff --git a/internal/worker/store/store.go b/internal/worker/store/store.go
index 4ab8556..ab6f80a 100644
--- a/internal/worker/store/store.go
+++ b/internal/worker/store/store.go
@@ -103,6 +103,8 @@
StartedAt, EndedAt time.Time
// The repo commit hash that this update is working on.
CommitHash string
+ // The time the commit occurred.
+ CommitTime time.Time
// The total number of CVEs being processed in this update.
NumTotal int
// The number currently processed. When this equals NumTotal, the
@@ -129,9 +131,9 @@
// CreateCommitUpdateRecord, because it will have the correct ID.
SetCommitUpdateRecord(context.Context, *CommitUpdateRecord) error
- // ListCommitUpdateRecords returns all the CommitUpdateRecords in the store, from most to
+ // ListCommitUpdateRecords returns some the CommitUpdateRecords in the store, from most to
// least recent.
- ListCommitUpdateRecords(context.Context) ([]*CommitUpdateRecord, error)
+ ListCommitUpdateRecords(ctx context.Context, limit int) ([]*CommitUpdateRecord, error)
// RunTransaction runs the function in a transaction.
RunTransaction(context.Context, func(context.Context, Transaction) error) error
diff --git a/internal/worker/store/store_test.go b/internal/worker/store/store_test.go
index cc1dfb6..777c475 100644
--- a/internal/worker/store/store_test.go
+++ b/internal/worker/store/store_test.go
@@ -53,7 +53,7 @@
u2.NumAdded = 40
u2.NumModified = 40
must(t, s.SetCommitUpdateRecord(ctx, u2))
- got, err := s.ListCommitUpdateRecords(ctx)
+ got, err := s.ListCommitUpdateRecords(ctx, 0)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/worker/update.go b/internal/worker/update.go
index a89b9a8..7fe48ea 100644
--- a/internal/worker/update.go
+++ b/internal/worker/update.go
@@ -51,11 +51,16 @@
log.Info(ctx, "update starting", event.String("commit", commitHash.String()))
+ commit, err := repo.CommitObject(commitHash)
+ if err != nil {
+ return err
+ }
+
// Get all the CVE files.
// It is cheaper to read all the files from the repo and compare
// them to the DB in bulk, than to walk the repo and process
// each file individually.
- files, err := repoCVEFiles(repo, commitHash)
+ files, err := repoCVEFiles(repo, commit)
if err != nil {
return err
}
@@ -63,6 +68,7 @@
ur := &store.CommitUpdateRecord{
StartedAt: time.Now(),
CommitHash: commitHash.String(),
+ CommitTime: commit.Committer.When,
NumTotal: len(files),
}
if err := st.CreateCommitUpdateRecord(ctx, ur); err != nil {
@@ -236,13 +242,9 @@
// repoCVEFiles returns all the CVE files in the given repo commit, sorted by
// name.
-func repoCVEFiles(repo *git.Repository, commitHash plumbing.Hash) (_ []repoFile, err error) {
- defer derrors.Wrap(&err, "repoCVEFiles(%s)", commitHash)
+func repoCVEFiles(repo *git.Repository, commit *object.Commit) (_ []repoFile, err error) {
+ defer derrors.Wrap(&err, "repoCVEFiles(%s)", commit.Hash)
- commit, err := repo.CommitObject(commitHash)
- if err != nil {
- return nil, fmt.Errorf("CommitObject: %w", err)
- }
root, err := repo.TreeObject(commit.TreeHash)
if err != nil {
return nil, fmt.Errorf("TreeObject: %v", err)
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index 514b9f2..fa490c9 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -24,11 +24,11 @@
if err != nil {
t.Fatal(err)
}
- h, err := headHash(repo)
+ commit := headCommit(t, repo)
if err != nil {
t.Fatal(err)
}
- got, err := repoCVEFiles(repo, h)
+ got, err := repoCVEFiles(repo, commit)
if err != nil {
t.Fatal(err)
}
@@ -62,6 +62,7 @@
if err != nil {
t.Fatal(err)
}
+
commitHash := ref.Hash().String()
const (
path1 = "2021/0xxx/CVE-2021-0001.json"
@@ -179,12 +180,20 @@
}
func readCVE(t *testing.T, repo *git.Repository, path string) (*cveschema.CVE, string) {
- blob := findBlob(t, repo, path)
- var cve cveschema.CVE
- if err := json.Unmarshal(readBlob(t, blob), &cve); err != nil {
+ c := headCommit(t, repo)
+ file, err := c.File(path)
+ if err != nil {
t.Fatal(err)
}
- return &cve, blob.Hash.String()
+ var cve cveschema.CVE
+ r, err := file.Reader()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := json.NewDecoder(r).Decode(&cve); err != nil {
+ t.Fatal(err)
+ }
+ return &cve, file.Hash.String()
}
func createCVERecords(t *testing.T, s store.Store, crs []*store.CVERecord) {