cmd/govulncheck: adds basic support for cache thread safety
Change-Id: Ifb79abf2f863787d37c152de9c668138597c1dc7
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/396594
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/cmd/govulncheck/cache.go b/cmd/govulncheck/cache.go
index 72feb24..26dd00f 100644
--- a/cmd/govulncheck/cache.go
+++ b/cmd/govulncheck/cache.go
@@ -13,18 +13,13 @@
"io/ioutil"
"os"
"path/filepath"
+ "sync"
"time"
"golang.org/x/vuln/client"
"golang.org/x/vuln/osv"
)
-// NOTE: this cache implementation should be kept internal to the go tooling
-// (i.e. cmd/go/internal/something) so that the vulndb cache is owned by the
-// go command. Also it is currently NOT CONCURRENCY SAFE since it does not
-// implement file locking. If ported to the stdlib it should use
-// cmd/go/internal/lockedfile.
-
// The cache uses a single JSON index file for each vulnerability database
// which contains the map from packages to the time the last
// vulnerability for that package was added/modified and the time that
@@ -43,9 +38,11 @@
// $GOPATH/pkg/mod/cache/download/vulndb/{db hostname}/{import path}/vulns.json
// []*osv.Entry
-// fsCache is file-system cache implementing osv.Cache
-// TODO: make cache thread-safe
+// fsCache is a thread-safe file-system cache implementing osv.Cache
+//
+// TODO: use something like cmd/go/internal/lockedfile for thread safety?
type fsCache struct {
+ mu sync.Mutex
rootDir string
}
@@ -62,6 +59,9 @@
}
func (c *fsCache) ReadIndex(dbName string) (client.DBIndex, time.Time, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
b, err := ioutil.ReadFile(filepath.Join(c.rootDir, dbName, "index.json"))
if err != nil {
if os.IsNotExist(err) {
@@ -77,6 +77,9 @@
}
func (c *fsCache) WriteIndex(dbName string, index client.DBIndex, retrieved time.Time) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
path := filepath.Join(c.rootDir, dbName)
if err := os.MkdirAll(path, 0755); err != nil {
return err
@@ -95,6 +98,9 @@
}
func (c *fsCache) ReadEntries(dbName string, p string) ([]*osv.Entry, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
b, err := ioutil.ReadFile(filepath.Join(c.rootDir, dbName, p, "vulns.json"))
if err != nil {
if os.IsNotExist(err) {
@@ -110,6 +116,9 @@
}
func (c *fsCache) WriteEntries(dbName string, p string, entries []*osv.Entry) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
path := filepath.Join(c.rootDir, dbName, p)
if err := os.MkdirAll(path, 0777); err != nil {
return err
diff --git a/cmd/govulncheck/cache_test.go b/cmd/govulncheck/cache_test.go
index be53956..09c8f11 100644
--- a/cmd/govulncheck/cache_test.go
+++ b/cmd/govulncheck/cache_test.go
@@ -8,12 +8,14 @@
package main
import (
+ "fmt"
"os"
"path/filepath"
"reflect"
"testing"
"time"
+ "golang.org/x/sync/errgroup"
"golang.org/x/vuln/client"
"golang.org/x/vuln/osv"
)
@@ -79,3 +81,85 @@
t.Errorf("ReadEntries returned unexpected entries, got:\n%v\nwant:\n%v", entries, expectedEntries)
}
}
+
+func TestConcurrency(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ cache := &fsCache{rootDir: tmpDir}
+ dbName := "vulndb.golang.org"
+
+ g := new(errgroup.Group)
+ for i := 0; i < 1000; i++ {
+ i := i
+ g.Go(func() error {
+ id := i % 5
+ p := fmt.Sprintf("package%d", id)
+
+ entries, err := cache.ReadEntries(dbName, p)
+ if err != nil {
+ return err
+ }
+
+ err = cache.WriteEntries(dbName, p, append(entries, &osv.Entry{ID: fmt.Sprint(id)}))
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ }
+
+ if err := g.Wait(); err != nil {
+ t.Errorf("error in parallel cache entries read/write: %v", err)
+ }
+
+ // sanity checking
+ for i := 0; i < 5; i++ {
+ id := fmt.Sprint(i)
+ p := fmt.Sprintf("package%s", id)
+
+ es, err := cache.ReadEntries(dbName, p)
+ if err != nil {
+ t.Fatalf("failed to read entries: %v", err)
+ }
+ for _, e := range es {
+ if e.ID != id {
+ t.Errorf("want %s ID for vuln entry; got %s", id, e.ID)
+ }
+ }
+ }
+
+ // do similar for cache index
+ start := time.Now()
+ for i := 0; i < 1000; i++ {
+ i := i
+ g.Go(func() error {
+ id := i % 5
+ p := fmt.Sprintf("package%v", id)
+
+ idx, _, err := cache.ReadIndex(dbName)
+ if err != nil {
+ return err
+ }
+
+ if idx == nil {
+ idx = client.DBIndex{}
+ }
+
+ // sanity checking
+ if rt, ok := idx[p]; ok && !start.Before(rt) {
+ return fmt.Errorf("unexpected past time in index: start %v not before %v", start, rt)
+ }
+
+ now := time.Now()
+ idx[p] = now
+ if err := cache.WriteIndex(dbName, idx, now); err != nil {
+ return err
+ }
+ return nil
+ })
+ }
+
+ if err := g.Wait(); err != nil {
+ t.Errorf("error in parallel cache index read/write: %v", err)
+ }
+}
diff --git a/go.mod b/go.mod
index 5da6ce5..d03cb52 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@
github.com/client9/misspell v0.3.4
github.com/google/go-cmp v0.5.6
golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57
+ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/tools v0.1.8
honnef.co/go/tools v0.2.2
mvdan.cc/unparam v0.0.0-20211214103731-d0ef000c54e5
diff --git a/go.sum b/go.sum
index 8c71487..d6c958d 100644
--- a/go.sum
+++ b/go.sum
@@ -28,6 +28,7 @@
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=