blob: 56ccc45d1aca6b4fba7892e484fa38c9fe86c351 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package store
import (
"context"
"errors"
"fmt"
"math/rand"
"sort"
"sync"
"time"
)
// MemStore is an in-memory implementation of Store, for testing.
// All data is kept in the maps and slice below; nothing is persisted.
type MemStore struct {
// mu is held for the duration of each RunTransaction call.
mu sync.Mutex
// cveRecords holds CVE records keyed by CVERecord.ID.
cveRecords map[string]*CVERecord
// updateRecords holds commit update records keyed by CommitUpdateRecord.ID.
updateRecords map[string]*CommitUpdateRecord
// dirHashes maps a directory name to the hash stored for it.
dirHashes map[string]string
// ghsaRecords holds GHSA records keyed by GHSARecord.GHSA.ID.
ghsaRecords map[string]*GHSARecord
// modScanRecords holds module scan records in the order they were created.
modScanRecords []*ModuleScanRecord
}
// NewMemStore returns a MemStore whose collections are allocated and empty.
func NewMemStore() *MemStore {
	ms := new(MemStore)
	// Clear allocates the maps. It always returns nil, so the result is
	// deliberately discarded.
	_ = ms.Clear(context.Background())
	return ms
}
// Clear removes all data from the MemStore by replacing every collection
// with a fresh, empty one. The context argument is unused and the returned
// error is always nil.
func (ms *MemStore) Clear(context.Context) error {
	ms.modScanRecords = nil
	ms.ghsaRecords = make(map[string]*GHSARecord)
	ms.dirHashes = make(map[string]string)
	ms.updateRecords = make(map[string]*CommitUpdateRecord)
	ms.cveRecords = make(map[string]*CVERecord)
	return nil
}
// CVERecords returns all the CVERecords of the store, keyed by ID.
// The result is the store's own map, not a copy.
func (ms *MemStore) CVERecords() map[string]*CVERecord {
	records := ms.cveRecords
	return records
}
// CreateCommitUpdateRecord implements Store.CreateCommitUpdateRecord.
//
// It assigns r a random ID that is not already in use, sets r.UpdatedAt to
// the current time, and saves the record with SetCommitUpdateRecord. Both
// r.ID and r.UpdatedAt are written to the caller's record.
func (ms *MemStore) CreateCommitUpdateRecord(ctx context.Context, r *CommitUpdateRecord) error {
	// Draw random IDs until one is unused, instead of panicking on the
	// unlikely collision with an existing record.
	for {
		id := fmt.Sprint(rand.Uint32())
		if ms.updateRecords[id] == nil {
			r.ID = id
			break
		}
	}
	r.UpdatedAt = time.Now()
	return ms.SetCommitUpdateRecord(ctx, r)
}
// SetCommitUpdateRecord implements Store.SetCommitUpdateRecord.
//
// It saves a copy of r under r.ID, replacing any record already stored with
// that ID. The stored copy's UpdatedAt is set to the current time; r itself
// is left unmodified. A record with an empty ID is rejected with an error.
func (ms *MemStore) SetCommitUpdateRecord(_ context.Context, r *CommitUpdateRecord) error {
	if len(r.ID) == 0 {
		return errors.New("SetCommitUpdateRecord: need ID")
	}
	stored := new(CommitUpdateRecord)
	*stored = *r
	stored.UpdatedAt = time.Now()
	ms.updateRecords[stored.ID] = stored
	return nil
}
// ListCommitUpdateRecords implements Store.ListCommitUpdateRecords.
//
// It returns the stored records ordered from most to least recent StartedAt.
// A positive limit caps the number of records returned; a limit of zero or
// less returns them all. The returned pointers refer to the stored records,
// not to copies.
func (ms *MemStore) ListCommitUpdateRecords(_ context.Context, limit int) ([]*CommitUpdateRecord, error) {
	var recs []*CommitUpdateRecord
	for _, rec := range ms.updateRecords {
		recs = append(recs, rec)
	}
	newestFirst := func(a, b int) bool {
		return recs[a].StartedAt.After(recs[b].StartedAt)
	}
	sort.Slice(recs, newestFirst)
	if limit <= 0 || limit >= len(recs) {
		return recs, nil
	}
	return recs[:limit], nil
}
// GetCVERecord implements Store.GetCVERecord.
// It returns the record stored under id, or nil with a nil error when there
// is none. The context argument is unused.
func (ms *MemStore) GetCVERecord(ctx context.Context, id string) (*CVERecord, error) {
	rec := ms.cveRecords[id]
	return rec, nil
}
// ListCVERecordsWithTriageState implements Store.ListCVERecordsWithTriageState.
//
// It returns every stored CVE record whose TriageState equals ts, sorted by
// ascending ID. The returned pointers refer to the stored records, not to
// copies.
func (ms *MemStore) ListCVERecordsWithTriageState(_ context.Context, ts TriageState) ([]*CVERecord, error) {
	var matches []*CVERecord
	for _, rec := range ms.cveRecords {
		if rec.TriageState != ts {
			continue
		}
		matches = append(matches, rec)
	}
	byID := func(a, b int) bool { return matches[a].ID < matches[b].ID }
	sort.Slice(matches, byID)
	return matches, nil
}
// CreateModuleScanRecord implements Store.CreateModuleScanRecord.
//
// It validates r and, if valid, appends it to the store. The pointer itself
// is stored, not a copy. A validation failure is returned unchanged and
// nothing is stored.
func (ms *MemStore) CreateModuleScanRecord(_ context.Context, r *ModuleScanRecord) error {
	err := r.Validate()
	if err != nil {
		return err
	}
	ms.modScanRecords = append(ms.modScanRecords, r)
	return nil
}
// GetModuleScanRecord implements Store.GetModuleScanRecord.
//
// Among the stored records whose Path, Version and DBTime match the
// arguments, it returns the one with the latest FinishedAt. It returns nil
// with a nil error when no record matches.
func (ms *MemStore) GetModuleScanRecord(_ context.Context, path, version string, dbTime time.Time) (*ModuleScanRecord, error) {
	var latest *ModuleScanRecord
	for _, rec := range ms.modScanRecords {
		if rec.Path != path || rec.Version != version || !rec.DBTime.Equal(dbTime) {
			continue
		}
		// Keep the current candidate unless rec finished strictly later.
		if latest != nil && !latest.FinishedAt.Before(rec.FinishedAt) {
			continue
		}
		latest = rec
	}
	return latest, nil
}
// ListModuleScanRecords implements Store.ListModuleScanRecords.
//
// It returns the stored records ordered from most to least recent
// FinishedAt. A positive limit caps the number of records returned; a limit
// of zero or less returns them all, matching ListCommitUpdateRecords.
// The returned slice is a copy, so sorting does not disturb the store's own
// ordering, but its elements point to the stored records.
func (ms *MemStore) ListModuleScanRecords(ctx context.Context, limit int) ([]*ModuleScanRecord, error) {
	rs := make([]*ModuleScanRecord, len(ms.modScanRecords))
	copy(rs, ms.modScanRecords)
	sort.Slice(rs, func(i, j int) bool { return rs[i].FinishedAt.After(rs[j].FinishedAt) })
	// Treat any non-positive limit as "no limit": slicing with a negative
	// limit would panic with a slice-bounds error.
	if limit <= 0 || limit >= len(rs) {
		return rs, nil
	}
	return rs[:limit], nil
}
// GetDirectoryHash implements Transaction.GetDirectoryHash.
// It returns the hash stored for dir, or the empty string if none has been
// set. The context argument is unused, and the returned error is always nil.
//
// NOTE(review): the receiver is *MemStore rather than memTransaction, so
// the interface named above may actually be Store; confirm.
func (ms *MemStore) GetDirectoryHash(_ context.Context, dir string) (string, error) {
	hash := ms.dirHashes[dir]
	return hash, nil
}
// SetDirectoryHash implements Transaction.SetDirectoryHash.
// It records hash as the hash for dir, replacing any earlier value. The
// context argument is unused, and the returned error is always nil.
func (ms *MemStore) SetDirectoryHash(_ context.Context, dir, hash string) error {
ms.dirHashes[dir] = hash
return nil
}
// RunTransaction implements Store.RunTransaction.
//
// The store's mutex is held for the entire call to f, so transactions never
// overlap one another. Because sync.Mutex is not reentrant, f must not call
// RunTransaction on the same store. The error returned by f is returned
// unchanged.
func (ms *MemStore) RunTransaction(ctx context.Context, f func(context.Context, Transaction) error) error {
	ms.mu.Lock()
	defer ms.mu.Unlock()

	return f(ctx, &memTransaction{ms: ms})
}
// memTransaction implements Store.Transaction.
// RunTransaction passes one to its callback as the Transaction argument.
// Its methods operate directly on the collections of the wrapped MemStore.
type memTransaction struct {
// ms is the store whose data this transaction reads and writes.
ms *MemStore
}
// CreateCVERecord implements Transaction.CreateCVERecord.
//
// It validates r and stores it under r.ID. The pointer itself is stored, not
// a copy, and any record already stored under that ID is replaced. A
// validation failure is returned unchanged and nothing is stored.
func (tx *memTransaction) CreateCVERecord(r *CVERecord) error {
	err := r.Validate()
	if err != nil {
		return err
	}
	records := tx.ms.cveRecords
	records[r.ID] = r
	return nil
}
// SetCVERecord implements Transaction.SetCVERecord.
//
// It validates r and replaces the record already stored under r.ID with r.
// The pointer itself is stored, not a copy. It returns an error if r is
// invalid or if no record with that ID exists.
func (tx *memTransaction) SetCVERecord(r *CVERecord) error {
	err := r.Validate()
	if err != nil {
		return err
	}
	records := tx.ms.cveRecords
	if existing := records[r.ID]; existing == nil {
		return fmt.Errorf("CVERecord with ID %q not found", r.ID)
	}
	records[r.ID] = r
	return nil
}
// GetCVERecords implements Transaction.GetCVERecords.
//
// It returns shallow copies of the records whose IDs lie in the inclusive
// range [startID, endID], compared as strings, sorted by ascending ID.
func (tx *memTransaction) GetCVERecords(startID, endID string) ([]*CVERecord, error) {
	var inRange []*CVERecord
	for id, rec := range tx.ms.cveRecords {
		if id < startID || id > endID {
			continue
		}
		dup := new(CVERecord)
		*dup = *rec
		inRange = append(inRange, dup)
	}
	// Map iteration order is random; sort so tests see a stable order.
	byID := func(a, b int) bool { return inRange[a].ID < inRange[b].ID }
	sort.Slice(inRange, byID)
	return inRange, nil
}
// CreateGHSARecord implements Transaction.CreateGHSARecord.
//
// It stores r under r.GHSA.ID. The pointer itself is stored, not a copy.
// It returns an error, leaving the store unchanged, if a record with that
// ID is already present.
func (tx *memTransaction) CreateGHSARecord(r *GHSARecord) error {
	id := r.GHSA.ID
	records := tx.ms.ghsaRecords
	if _, exists := records[id]; exists {
		return fmt.Errorf("GHSARecord %s already exists", id)
	}
	records[id] = r
	return nil
}
// SetGHSARecord implements Transaction.SetGHSARecord.
//
// It replaces the record stored under r.GHSA.ID with r. The pointer itself
// is stored, not a copy. It returns an error, leaving the store unchanged,
// if no record with that ID is present.
func (tx *memTransaction) SetGHSARecord(r *GHSARecord) error {
	id := r.GHSA.ID
	records := tx.ms.ghsaRecords
	if _, exists := records[id]; !exists {
		return fmt.Errorf("GHSARecord %s does not exist", id)
	}
	records[id] = r
	return nil
}
// GetGHSARecord implements Transaction.GetGHSARecord.
// It returns the record stored under id, or nil with a nil error when there
// is none.
func (tx *memTransaction) GetGHSARecord(id string) (*GHSARecord, error) {
	// Indexing a map with a missing key yields the zero value, which is a
	// nil pointer here, so no explicit presence check is needed.
	return tx.ms.ghsaRecords[id], nil
}
// GetGHSARecords implements Transaction.GetGHSARecords.
//
// It returns every stored GHSA record, ordered by the ID each was stored
// under. The returned pointers refer to the stored records, not to copies.
// The result is nil when the store holds no GHSA records.
func (tx *memTransaction) GetGHSARecords() ([]*GHSARecord, error) {
	records := tx.ms.ghsaRecords
	ids := make([]string, 0, len(records))
	for id := range records {
		ids = append(ids, id)
	}
	// Map iteration order is random; sort the keys so callers and tests see
	// a stable order, as GetCVERecords already provides.
	sort.Strings(ids)

	var recs []*GHSARecord
	for _, id := range ids {
		recs = append(recs, records[id])
	}
	return recs, nil
}