cmd/govulncheck: adds basic support for cache thread safety

Change-Id: Ifb79abf2f863787d37c152de9c668138597c1dc7
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/396594
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/cmd/govulncheck/cache.go b/cmd/govulncheck/cache.go
index 72feb24..26dd00f 100644
--- a/cmd/govulncheck/cache.go
+++ b/cmd/govulncheck/cache.go
@@ -13,18 +13,13 @@
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sync"
 	"time"
 
 	"golang.org/x/vuln/client"
 	"golang.org/x/vuln/osv"
 )
 
-// NOTE: this cache implementation should be kept internal to the go tooling
-// (i.e. cmd/go/internal/something) so that the vulndb cache is owned by the
-// go command. Also it is currently NOT CONCURRENCY SAFE since it does not
-// implement file locking. If ported to the stdlib it should use
-// cmd/go/internal/lockedfile.
-
 // The cache uses a single JSON index file for each vulnerability database
 // which contains the map from packages to the time the last
 // vulnerability for that package was added/modified and the time that
@@ -43,9 +38,11 @@
 // $GOPATH/pkg/mod/cache/download/vulndb/{db hostname}/{import path}/vulns.json
 //   []*osv.Entry
 
-// fsCache is file-system cache implementing osv.Cache
-// TODO: make cache thread-safe
+// fsCache is a thread-safe file-system cache implementing osv.Cache
+//
+// TODO: use something like cmd/go/internal/lockedfile for thread safety?
 type fsCache struct {
+	mu      sync.Mutex
 	rootDir string
 }
 
@@ -62,6 +59,9 @@
 }
 
 func (c *fsCache) ReadIndex(dbName string) (client.DBIndex, time.Time, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	b, err := ioutil.ReadFile(filepath.Join(c.rootDir, dbName, "index.json"))
 	if err != nil {
 		if os.IsNotExist(err) {
@@ -77,6 +77,9 @@
 }
 
 func (c *fsCache) WriteIndex(dbName string, index client.DBIndex, retrieved time.Time) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	path := filepath.Join(c.rootDir, dbName)
 	if err := os.MkdirAll(path, 0755); err != nil {
 		return err
@@ -95,6 +98,9 @@
 }
 
 func (c *fsCache) ReadEntries(dbName string, p string) ([]*osv.Entry, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	b, err := ioutil.ReadFile(filepath.Join(c.rootDir, dbName, p, "vulns.json"))
 	if err != nil {
 		if os.IsNotExist(err) {
@@ -110,6 +116,9 @@
 }
 
 func (c *fsCache) WriteEntries(dbName string, p string, entries []*osv.Entry) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	path := filepath.Join(c.rootDir, dbName, p)
 	if err := os.MkdirAll(path, 0777); err != nil {
 		return err
diff --git a/cmd/govulncheck/cache_test.go b/cmd/govulncheck/cache_test.go
index be53956..09c8f11 100644
--- a/cmd/govulncheck/cache_test.go
+++ b/cmd/govulncheck/cache_test.go
@@ -8,12 +8,14 @@
 package main
 
 import (
+	"fmt"
 	"os"
 	"path/filepath"
 	"reflect"
 	"testing"
 	"time"
 
+	"golang.org/x/sync/errgroup"
 	"golang.org/x/vuln/client"
 	"golang.org/x/vuln/osv"
 )
@@ -79,3 +81,85 @@
 		t.Errorf("ReadEntries returned unexpected entries, got:\n%v\nwant:\n%v", entries, expectedEntries)
 	}
 }
+
+func TestConcurrency(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	cache := &fsCache{rootDir: tmpDir}
+	dbName := "vulndb.golang.org"
+
+	g := new(errgroup.Group)
+	for i := 0; i < 1000; i++ {
+		i := i
+		g.Go(func() error {
+			id := i % 5
+			p := fmt.Sprintf("package%d", id)
+
+			entries, err := cache.ReadEntries(dbName, p)
+			if err != nil {
+				return err
+			}
+
+			err = cache.WriteEntries(dbName, p, append(entries, &osv.Entry{ID: fmt.Sprint(id)}))
+			if err != nil {
+				return err
+			}
+			return nil
+		})
+	}
+
+	if err := g.Wait(); err != nil {
+		t.Errorf("error in parallel cache entries read/write: %v", err)
+	}
+
+	// sanity checking
+	for i := 0; i < 5; i++ {
+		id := fmt.Sprint(i)
+		p := fmt.Sprintf("package%s", id)
+
+		es, err := cache.ReadEntries(dbName, p)
+		if err != nil {
+			t.Fatalf("failed to read entries: %v", err)
+		}
+		for _, e := range es {
+			if e.ID != id {
+				t.Errorf("want %s ID for vuln entry; got %s", id, e.ID)
+			}
+		}
+	}
+
+	// do similar for cache index
+	start := time.Now()
+	for i := 0; i < 1000; i++ {
+		i := i
+		g.Go(func() error {
+			id := i % 5
+			p := fmt.Sprintf("package%v", id)
+
+			idx, _, err := cache.ReadIndex(dbName)
+			if err != nil {
+				return err
+			}
+
+			if idx == nil {
+				idx = client.DBIndex{}
+			}
+
+			// sanity checking
+			if rt, ok := idx[p]; ok && !start.Before(rt) {
+				return fmt.Errorf("unexpected past time in index: start %v not before %v", start, rt)
+			}
+
+			now := time.Now()
+			idx[p] = now
+			if err := cache.WriteIndex(dbName, idx, now); err != nil {
+				return err
+			}
+			return nil
+		})
+	}
+
+	if err := g.Wait(); err != nil {
+		t.Errorf("error in parallel cache index read/write: %v", err)
+	}
+}
diff --git a/go.mod b/go.mod
index 5da6ce5..d03cb52 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@
 	github.com/client9/misspell v0.3.4
 	github.com/google/go-cmp v0.5.6
 	golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57
+	golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
 	golang.org/x/tools v0.1.8
 	honnef.co/go/tools v0.2.2
 	mvdan.cc/unparam v0.0.0-20211214103731-d0ef000c54e5
diff --git a/go.sum b/go.sum
index 8c71487..d6c958d 100644
--- a/go.sum
+++ b/go.sum
@@ -28,6 +28,7 @@
 golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=