internal/worker/store: support module scanning
Add a ModuleScanRecord type and code to manipulate it in the DB.
Each scanned module results in a ModuleScanRecord.
Change-Id: Icc3facea31980cf14c63b594d6d730a79ce27e62
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/393694
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/worker/store/fire_store.go b/internal/worker/store/fire_store.go
index a4c20ff..5017b95 100644
--- a/internal/worker/store/fire_store.go
+++ b/internal/worker/store/fire_store.go
@@ -9,6 +9,7 @@
"errors"
"fmt"
"strings"
+ "time"
"cloud.google.com/go/firestore"
"golang.org/x/vulndb/internal/derrors"
@@ -32,6 +33,7 @@
// - CommitUpdates for CommitUpdateRecords
// - DirHashes for directory hashes
// - GHSAs for GHSARecords.
+// - ModuleScans for ModuleScanRecords.
type FireStore struct {
namespace string
client *firestore.Client
@@ -44,6 +46,7 @@
cveCollection = "CVEs"
dirHashCollection = "DirHashes"
ghsaCollection = "GHSAs"
+ modScanCollection = "ModuleScans"
)
// NewFireStore creates a new FireStore, backed by a client to Firestore. Since
@@ -124,20 +127,18 @@
q = q.Limit(limit)
}
iter := q.Documents(ctx)
- for {
- docsnap, err := iter.Next()
- if err == iterator.Done {
- break
- }
- if err != nil {
- return nil, err
- }
+ defer iter.Stop()
+ err := apply(iter, func(ds *firestore.DocumentSnapshot) error {
var ur CommitUpdateRecord
- if err := docsnap.DataTo(&ur); err != nil {
- return nil, err
+ if err := ds.DataTo(&ur); err != nil {
+ return err
}
- ur.ID = docsnap.Ref.ID
+ ur.ID = ds.Ref.ID
urs = append(urs, &ur)
+ return nil
+ })
+ if err != nil {
+ return nil, err
}
return urs, nil
}
@@ -158,6 +159,62 @@
return docsnapsToCVERecords(docsnaps)
}
+// CreateModuleScanRecord implements Store.CreateModuleScanRecord.
+func (fs *FireStore) CreateModuleScanRecord(ctx context.Context, r *ModuleScanRecord) error {
+ if err := r.Validate(); err != nil {
+ return err
+ }
+ docref := fs.nsDoc.Collection(modScanCollection).NewDoc()
+ _, err := docref.Create(ctx, r)
+ return err
+}
+
+// GetModuleScanRecord implements store.GetModuleScanRecord.
+func (fs *FireStore) GetModuleScanRecord(ctx context.Context, path, version string, dbTime time.Time) (*ModuleScanRecord, error) {
+ // There may be several, but we only need one; take the most recent.
+ q := fs.nsDoc.Collection(modScanCollection).
+ Where("Path", "==", path).
+ Where("Version", "==", version).
+ Where("DBTime", "==", dbTime).
+ OrderBy("FinishedAt", firestore.Desc)
+ docsnaps, err := q.Documents(ctx).GetAll()
+ if err != nil {
+ return nil, err
+ }
+ if len(docsnaps) == 0 {
+ return nil, nil
+ }
+
+ var r ModuleScanRecord
+ if err := docsnaps[0].DataTo(&r); err != nil {
+ return nil, err
+ }
+ return &r, nil
+}
+
+// ListModuleScanRecords implements Store.ListModuleScanRecords.
+func (fs *FireStore) ListModuleScanRecords(ctx context.Context, limit int) ([]*ModuleScanRecord, error) {
+ q := fs.nsDoc.Collection(modScanCollection).OrderBy("FinishedAt", firestore.Desc)
+ if limit > 0 {
+ q = q.Limit(limit)
+ }
+ var rs []*ModuleScanRecord
+ iter := q.Documents(ctx)
+ defer iter.Stop()
+ err := apply(iter, func(ds *firestore.DocumentSnapshot) error {
+ var r ModuleScanRecord
+ if err := ds.DataTo(&r); err != nil {
+ return err
+ }
+ rs = append(rs, &r)
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return rs, nil
+}
+
// dirHashRef returns a DocumentRef for the directory dir.
func (s *FireStore) dirHashRef(dir string) *firestore.DocumentRef {
// Firestore IDs cannot contain slashes.
@@ -369,3 +426,20 @@
}
}
}
+
+// apply calls f for each element of iter. If f returns an error, apply stops
+// immediately and returns the same error.
+func apply(iter *firestore.DocumentIterator, f func(*firestore.DocumentSnapshot) error) error {
+ for {
+ docsnap, err := iter.Next()
+ if err == iterator.Done {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ if err := f(docsnap); err != nil {
+ return err
+ }
+ }
+}
diff --git a/internal/worker/store/mem_store.go b/internal/worker/store/mem_store.go
index 283c092..5029094 100644
--- a/internal/worker/store/mem_store.go
+++ b/internal/worker/store/mem_store.go
@@ -16,11 +16,12 @@
// MemStore is an in-memory implementation of Store, for testing.
type MemStore struct {
- mu sync.Mutex
- cveRecords map[string]*CVERecord
- updateRecords map[string]*CommitUpdateRecord
- dirHashes map[string]string
- ghsaRecords map[string]*GHSARecord
+ mu sync.Mutex
+ cveRecords map[string]*CVERecord
+ updateRecords map[string]*CommitUpdateRecord
+ dirHashes map[string]string
+ ghsaRecords map[string]*GHSARecord
+ modScanRecords []*ModuleScanRecord
}
// NewMemStore creates a new, empty MemStore.
@@ -36,6 +37,7 @@
ms.updateRecords = map[string]*CommitUpdateRecord{}
ms.dirHashes = map[string]string{}
ms.ghsaRecords = map[string]*GHSARecord{}
+ ms.modScanRecords = nil
return nil
}
@@ -99,6 +101,39 @@
return crs, nil
}
+// CreateModuleScanRecord implements Store.CreateModuleScanRecord.
+func (ms *MemStore) CreateModuleScanRecord(_ context.Context, r *ModuleScanRecord) error {
+ if err := r.Validate(); err != nil {
+ return err
+ }
+ ms.modScanRecords = append(ms.modScanRecords, r)
+ return nil
+}
+
+// GetModuleScanRecord implements store.GetModuleScanRecord.
+func (ms *MemStore) GetModuleScanRecord(_ context.Context, path, version string, dbTime time.Time) (*ModuleScanRecord, error) {
+ var m *ModuleScanRecord
+ for _, r := range ms.modScanRecords {
+ if r.Path == path && r.Version == version && r.DBTime.Equal(dbTime) {
+ if m == nil || m.FinishedAt.Before(r.FinishedAt) {
+ m = r
+ }
+ }
+ }
+ return m, nil
+}
+
+// ListModuleScanRecords implements Store.ListModuleScanRecords.
+func (ms *MemStore) ListModuleScanRecords(ctx context.Context, limit int) ([]*ModuleScanRecord, error) {
+ rs := make([]*ModuleScanRecord, len(ms.modScanRecords))
+ copy(rs, ms.modScanRecords)
+ sort.Slice(rs, func(i, j int) bool { return rs[i].FinishedAt.After(rs[j].FinishedAt) })
+ if limit == 0 || limit >= len(rs) {
+ return rs, nil
+ }
+ return rs[:limit], nil
+}
+
// GetDirectoryHash implements Transaction.GetDirectoryHash.
func (ms *MemStore) GetDirectoryHash(_ context.Context, dir string) (string, error) {
return ms.dirHashes[dir], nil
diff --git a/internal/worker/store/store.go b/internal/worker/store/store.go
index 9f87c88..624b929 100644
--- a/internal/worker/store/store.go
+++ b/internal/worker/store/store.go
@@ -201,6 +201,33 @@
func (r *GHSARecord) GetPrettyID() string { return r.GHSA.PrettyID() }
+// A ModuleScanRecord holds information about a vulnerability scan of a module.
+type ModuleScanRecord struct {
+ Path string
+ Version string
+ DBTime time.Time // last-modified time of the vuln DB
+ Error string // if non-empty, error while scanning
+ VulnIDs []string
+ FinishedAt time.Time // when the scan completed (successfully or not)
+}
+
+// Validate returns an error if the ModuleScanRecord is not valid.
+func (r *ModuleScanRecord) Validate() error {
+ if r.Path == "" {
+ return errors.New("need Path")
+ }
+ if r.Version == "" {
+ return errors.New("need Version")
+ }
+ if r.DBTime.IsZero() {
+ return errors.New("need DBTime")
+ }
+ if r.FinishedAt.IsZero() {
+ return errors.New("need FinishedAt")
+ }
+ return nil
+}
+
// A Store is a storage system for the CVE database.
type Store interface {
// CreateCommitUpdateRecord creates a new CommitUpdateRecord. It should be called at the start
@@ -212,7 +239,7 @@
// CreateCommitUpdateRecord, because it will have the correct ID.
SetCommitUpdateRecord(context.Context, *CommitUpdateRecord) error
- // ListCommitUpdateRecords returns some the CommitUpdateRecords in the store, from most to
+ // ListCommitUpdateRecords returns some of the CommitUpdateRecords in the store, from most to
// least recent.
ListCommitUpdateRecords(ctx context.Context, limit int) ([]*CommitUpdateRecord, error)
@@ -230,6 +257,18 @@
// SetDirectoryHash sets the hash for the given directory.
SetDirectoryHash(ctx context.Context, dir, hash string) error
+ // CreateModuleScanRecord adds a ModuleScanRecord to the DB.
+ CreateModuleScanRecord(context.Context, *ModuleScanRecord) error
+
+ // GetModuleScanRecord returns the most recent ModuleScanRecord matching the
+ // given module path, version and DB time. If not found, it returns (nil,
+ // nil).
+ GetModuleScanRecord(ctx context.Context, path, version string, dbTime time.Time) (*ModuleScanRecord, error)
+
+ // ListModuleScanRecords returns some of the ModuleScanRecords in the store
+ // from most to least recent. If limit is zero, all records are returned.
+ ListModuleScanRecords(ctx context.Context, limit int) ([]*ModuleScanRecord, error)
+
// RunTransaction runs the function in a transaction.
RunTransaction(context.Context, func(context.Context, Transaction) error) error
}
diff --git a/internal/worker/store/store_test.go b/internal/worker/store/store_test.go
index d21b8e2..3344287 100644
--- a/internal/worker/store/store_test.go
+++ b/internal/worker/store/store_test.go
@@ -39,6 +39,9 @@
t.Run("GHSAs", func(t *testing.T) {
testGHSAs(t, s)
})
+ t.Run("ModuleScanRecords", func(t *testing.T) {
+ testModuleScanRecords(t, s)
+ })
}
func testUpdates(t *testing.T, s Store) {
@@ -264,6 +267,62 @@
}
}
+func testModuleScanRecords(t *testing.T, s Store) {
+ ctx := context.Background()
+ tm := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)
+ rs := []*ModuleScanRecord{
+ {
+ Path: "m1",
+ Version: "v1.2.3",
+ DBTime: tm,
+ FinishedAt: tm,
+ },
+ {
+ Path: "m1",
+ Version: "v1.2.3",
+ DBTime: tm,
+ FinishedAt: tm.Add(time.Hour),
+ },
+ {
+ Path: "m2",
+ Version: "v1.2.3",
+ DBTime: tm,
+ FinishedAt: tm.Add(time.Hour * 2),
+ },
+ }
+ for _, r := range rs {
+ if err := s.CreateModuleScanRecord(ctx, r); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // GetModuleScanRecord
+ got, err := s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Expect the most recent.
+ if want := rs[1]; !cmp.Equal(got, want) {
+ t.Errorf("got\n%+v\nwant\n%+v", got, want)
+ }
+ // Non-existent record.
+ got, err = s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm.Add(time.Second))
+ if got != nil || err != nil {
+ t.Errorf("got (%v, %v), want (nil, nil)", got, err)
+ }
+
+ // ListModuleScanRecords
+ got2, err := s.ListModuleScanRecords(ctx, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ want := []*ModuleScanRecord{rs[2], rs[1], rs[0]}
+ if !cmp.Equal(got2, want) {
+ t.Errorf("got\n%+v\nwant\n%+v", got2, want)
+ }
+}
+
func createCVERecords(t *testing.T, ctx context.Context, s Store, crs []*CVERecord) {
err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
for _, cr := range crs {