internal/worker/store: support module scanning

Add a ModuleScanRecord type and code to manipulate it in the DB.
Each scanned module results in a ModuleScanRecord.

Change-Id: Icc3facea31980cf14c63b594d6d730a79ce27e62
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/393694
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/worker/store/fire_store.go b/internal/worker/store/fire_store.go
index a4c20ff..5017b95 100644
--- a/internal/worker/store/fire_store.go
+++ b/internal/worker/store/fire_store.go
@@ -9,6 +9,7 @@
 	"errors"
 	"fmt"
 	"strings"
+	"time"
 
 	"cloud.google.com/go/firestore"
 	"golang.org/x/vulndb/internal/derrors"
@@ -32,6 +33,7 @@
 // - CommitUpdates for CommitUpdateRecords
 // - DirHashes for directory hashes
 // - GHSAs for GHSARecords.
+// - ModuleScans for ModuleScanRecords.
 type FireStore struct {
 	namespace string
 	client    *firestore.Client
@@ -44,6 +46,7 @@
 	cveCollection       = "CVEs"
 	dirHashCollection   = "DirHashes"
 	ghsaCollection      = "GHSAs"
+	modScanCollection   = "ModuleScans"
 )
 
 // NewFireStore creates a new FireStore, backed by a client to Firestore. Since
@@ -124,20 +127,18 @@
 		q = q.Limit(limit)
 	}
 	iter := q.Documents(ctx)
-	for {
-		docsnap, err := iter.Next()
-		if err == iterator.Done {
-			break
-		}
-		if err != nil {
-			return nil, err
-		}
+	defer iter.Stop()
+	err := apply(iter, func(ds *firestore.DocumentSnapshot) error {
 		var ur CommitUpdateRecord
-		if err := docsnap.DataTo(&ur); err != nil {
-			return nil, err
+		if err := ds.DataTo(&ur); err != nil {
+			return err
 		}
-		ur.ID = docsnap.Ref.ID
+		ur.ID = ds.Ref.ID
 		urs = append(urs, &ur)
+		return nil
+	})
+	if err != nil {
+		return nil, err
 	}
 	return urs, nil
 }
@@ -158,6 +159,62 @@
 	return docsnapsToCVERecords(docsnaps)
 }
 
+// CreateModuleScanRecord implements Store.CreateModuleScanRecord.
+// A valid record is written as a new document in the ModuleScans collection.
+func (fs *FireStore) CreateModuleScanRecord(ctx context.Context, r *ModuleScanRecord) error {
+	if err := r.Validate(); err != nil {
+		return err
+	}
+	_, err := fs.nsDoc.Collection(modScanCollection).NewDoc().Create(ctx, r)
+	return err
+}
+
+// GetModuleScanRecord implements Store.GetModuleScanRecord.
+func (fs *FireStore) GetModuleScanRecord(ctx context.Context, path, version string, dbTime time.Time) (*ModuleScanRecord, error) {
+	// There may be several, but we only need one; take the most recent.
+	// Limit(1) keeps the query from sending back every matching document.
+	q := fs.nsDoc.Collection(modScanCollection).
+		Where("Path", "==", path).
+		Where("Version", "==", version).
+		Where("DBTime", "==", dbTime).
+		OrderBy("FinishedAt", firestore.Desc).Limit(1)
+	docsnaps, err := q.Documents(ctx).GetAll()
+	if err != nil {
+		return nil, err
+	}
+	if len(docsnaps) == 0 {
+		return nil, nil // not found
+	}
+	var r ModuleScanRecord
+	if err := docsnaps[0].DataTo(&r); err != nil {
+		return nil, err
+	}
+	return &r, nil
+}
+
+// ListModuleScanRecords implements Store.ListModuleScanRecords.
+func (fs *FireStore) ListModuleScanRecords(ctx context.Context, limit int) ([]*ModuleScanRecord, error) {
+	query := fs.nsDoc.Collection(modScanCollection).OrderBy("FinishedAt", firestore.Desc)
+	if limit > 0 {
+		query = query.Limit(limit)
+	}
+	docs := query.Documents(ctx)
+	defer docs.Stop()
+	var records []*ModuleScanRecord
+	collect := func(snap *firestore.DocumentSnapshot) error {
+		rec := new(ModuleScanRecord)
+		if err := snap.DataTo(rec); err != nil {
+			return err
+		}
+		records = append(records, rec)
+		return nil
+	}
+	if err := apply(docs, collect); err != nil {
+		return nil, err
+	}
+	return records, nil
+}
+
 // dirHashRef returns a DocumentRef for the directory dir.
 func (s *FireStore) dirHashRef(dir string) *firestore.DocumentRef {
 	// Firestore IDs cannot contain slashes.
@@ -369,3 +426,20 @@
 		}
 	}
 }
+
+// apply calls f on every document yielded by iter, in order. It stops at the
+// first error from either iter or f and returns that error.
+func apply(iter *firestore.DocumentIterator, f func(*firestore.DocumentSnapshot) error) error {
+	for {
+		snap, err := iter.Next()
+		switch {
+		case err == iterator.Done:
+			return nil
+		case err != nil:
+			return err
+		}
+		if ferr := f(snap); ferr != nil {
+			return ferr
+		}
+	}
+}
diff --git a/internal/worker/store/mem_store.go b/internal/worker/store/mem_store.go
index 283c092..5029094 100644
--- a/internal/worker/store/mem_store.go
+++ b/internal/worker/store/mem_store.go
@@ -16,11 +16,12 @@
 
 // MemStore is an in-memory implementation of Store, for testing.
 type MemStore struct {
-	mu            sync.Mutex
-	cveRecords    map[string]*CVERecord
-	updateRecords map[string]*CommitUpdateRecord
-	dirHashes     map[string]string
-	ghsaRecords   map[string]*GHSARecord
+	mu             sync.Mutex
+	cveRecords     map[string]*CVERecord
+	updateRecords  map[string]*CommitUpdateRecord
+	dirHashes      map[string]string
+	ghsaRecords    map[string]*GHSARecord
+	modScanRecords []*ModuleScanRecord // appended to by CreateModuleScanRecord
 }
 
 // NewMemStore creates a new, empty MemStore.
@@ -36,6 +37,7 @@
 	ms.updateRecords = map[string]*CommitUpdateRecord{}
 	ms.dirHashes = map[string]string{}
 	ms.ghsaRecords = map[string]*GHSARecord{}
+	ms.modScanRecords = nil
 	return nil
 }
 
@@ -99,6 +101,39 @@
 	return crs, nil
 }
 
+// CreateModuleScanRecord implements Store.CreateModuleScanRecord.
+func (ms *MemStore) CreateModuleScanRecord(_ context.Context, r *ModuleScanRecord) error {
+	err := r.Validate()
+	if err == nil {
+		ms.modScanRecords = append(ms.modScanRecords, r) // keeps r itself, not a copy
+	}
+	return err
+}
+
+// GetModuleScanRecord implements Store.GetModuleScanRecord. Of the records
+// matching all three arguments, it returns the one that finished last.
+func (ms *MemStore) GetModuleScanRecord(_ context.Context, path, version string, dbTime time.Time) (*ModuleScanRecord, error) {
+	var latest *ModuleScanRecord
+	for _, rec := range ms.modScanRecords {
+		matches := rec.Path == path && rec.Version == version && rec.DBTime.Equal(dbTime)
+		if matches && (latest == nil || rec.FinishedAt.After(latest.FinishedAt)) {
+			latest = rec
+		}
+	}
+	return latest, nil
+}
+
+// ListModuleScanRecords implements Store.ListModuleScanRecords.
+func (ms *MemStore) ListModuleScanRecords(ctx context.Context, limit int) ([]*ModuleScanRecord, error) {
+	rs := make([]*ModuleScanRecord, len(ms.modScanRecords))
+	copy(rs, ms.modScanRecords) // sort a copy so the stored order is left alone
+	sort.Slice(rs, func(i, j int) bool { return rs[i].FinishedAt.After(rs[j].FinishedAt) })
+	if limit <= 0 || limit >= len(rs) { // as in FireStore, a non-positive limit means no limit
+		return rs, nil
+	}
+	return rs[:limit], nil
+}
+
 // GetDirectoryHash implements Transaction.GetDirectoryHash.
 func (ms *MemStore) GetDirectoryHash(_ context.Context, dir string) (string, error) {
 	return ms.dirHashes[dir], nil
diff --git a/internal/worker/store/store.go b/internal/worker/store/store.go
index 9f87c88..624b929 100644
--- a/internal/worker/store/store.go
+++ b/internal/worker/store/store.go
@@ -201,6 +201,33 @@
 
 func (r *GHSARecord) GetPrettyID() string { return r.GHSA.PrettyID() }
 
+// A ModuleScanRecord holds information about a vulnerability scan of a module.
+type ModuleScanRecord struct {
+	Path       string    // module path
+	Version    string    // module version
+	DBTime     time.Time // last-modified time of the vuln DB
+	Error      string    // if non-empty, error while scanning
+	VulnIDs    []string
+	FinishedAt time.Time // when the scan completed (successfully or not)
+}
+
+// Validate reports the first required field that is unset. Path, Version,
+// DBTime and FinishedAt are all required; it returns nil when every one of
+// them is set.
+func (r *ModuleScanRecord) Validate() error {
+	switch {
+	case r.Path == "":
+		return errors.New("need Path")
+	case r.Version == "":
+		return errors.New("need Version")
+	case r.DBTime.IsZero():
+		return errors.New("need DBTime")
+	case r.FinishedAt.IsZero():
+		return errors.New("need FinishedAt")
+	}
+	return nil
+}
+
 // A Store is a storage system for the CVE database.
 type Store interface {
 	// CreateCommitUpdateRecord creates a new CommitUpdateRecord. It should be called at the start
@@ -212,7 +239,7 @@
 	// CreateCommitUpdateRecord, because it will have the correct ID.
 	SetCommitUpdateRecord(context.Context, *CommitUpdateRecord) error
 
-	// ListCommitUpdateRecords returns some the CommitUpdateRecords in the store, from most to
+	// ListCommitUpdateRecords returns some of the CommitUpdateRecords in the store, from most to
 	// least recent.
 	ListCommitUpdateRecords(ctx context.Context, limit int) ([]*CommitUpdateRecord, error)
 
@@ -230,6 +257,18 @@
 	// SetDirectoryHash sets the hash for the given directory.
 	SetDirectoryHash(ctx context.Context, dir, hash string) error
 
+	// CreateModuleScanRecord adds a ModuleScanRecord to the DB.
+	CreateModuleScanRecord(context.Context, *ModuleScanRecord) error
+
+	// GetModuleScanRecord returns the most recent ModuleScanRecord matching the
+	// given module path, version and DB time. If not found, it returns (nil,
+	// nil).
+	GetModuleScanRecord(ctx context.Context, path, version string, dbTime time.Time) (*ModuleScanRecord, error)
+
+	// ListModuleScanRecords returns some of the ModuleScanRecords in the store
+	// from most to least recent. If limit is zero, all records are returned.
+	ListModuleScanRecords(ctx context.Context, limit int) ([]*ModuleScanRecord, error)
+
 	// RunTransaction runs the function in a transaction.
 	RunTransaction(context.Context, func(context.Context, Transaction) error) error
 }
diff --git a/internal/worker/store/store_test.go b/internal/worker/store/store_test.go
index d21b8e2..3344287 100644
--- a/internal/worker/store/store_test.go
+++ b/internal/worker/store/store_test.go
@@ -39,6 +39,9 @@
 	t.Run("GHSAs", func(t *testing.T) {
 		testGHSAs(t, s)
 	})
+	t.Run("ModuleScanRecords", func(t *testing.T) {
+		testModuleScanRecords(t, s)
+	})
 }
 
 func testUpdates(t *testing.T, s Store) {
@@ -264,6 +267,62 @@
 	}
 }
 
+func testModuleScanRecords(t *testing.T, s Store) {
+	ctx := context.Background()
+	tm := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)
+	rs := []*ModuleScanRecord{
+		{
+			Path:       "m1",
+			Version:    "v1.2.3",
+			DBTime:     tm,
+			FinishedAt: tm,
+		},
+		{
+			Path:       "m1",
+			Version:    "v1.2.3",
+			DBTime:     tm,
+			FinishedAt: tm.Add(time.Hour),
+		},
+		{
+			Path:       "m2",
+			Version:    "v1.2.3",
+			DBTime:     tm,
+			FinishedAt: tm.Add(time.Hour * 2),
+		},
+	}
+	for _, r := range rs {
+		if err := s.CreateModuleScanRecord(ctx, r); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// GetModuleScanRecord: rs[0] and rs[1] both match; rs[1] finished later.
+	got, err := s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Expect the most recent of the matching records.
+	if want := rs[1]; !cmp.Equal(got, want) {
+		t.Errorf("got\n%+v\nwant\n%+v", got, want)
+	}
+	// Non-existent record: nothing was stored with this DBTime.
+	got, err = s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm.Add(time.Second))
+	if got != nil || err != nil {
+		t.Errorf("got (%v, %v), want (nil, nil)", got, err)
+	}
+
+	// ListModuleScanRecords: a limit of 0 asks for every record, latest first.
+	got2, err := s.ListModuleScanRecords(ctx, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	want := []*ModuleScanRecord{rs[2], rs[1], rs[0]}
+	if !cmp.Equal(got2, want) {
+		t.Errorf("got\n%+v\nwant\n%+v", got2, want)
+	}
+}
+
 func createCVERecords(t *testing.T, ctx context.Context, s Store, crs []*CVERecord) {
 	err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 		for _, cr := range crs {