internal/pkgsite: make pkgsite cache thread-safe
Add a mutex to protect the cache and add a test that would fail
with the "-race" flag without this fix.
Change-Id: I13e2bbd4d6019f425959cd8c660b4b6123ca8162
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/554795
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Auto-Submit: Tatiana Bradley <tatianabradley@google.com>
diff --git a/internal/pkgsite/client.go b/internal/pkgsite/client.go
index e9d1987..7ef4d59 100644
--- a/internal/pkgsite/client.go
+++ b/internal/pkgsite/client.go
@@ -15,6 +15,7 @@
"path/filepath"
"strconv"
"strings"
+ "sync"
"testing"
"time"
@@ -23,11 +24,8 @@
)
type Client struct {
- url string
- // Cache of module paths already seen.
- seen map[string]bool
- // Does seen contain all known modules?
- cacheComplete bool
+ url string
+ cache *cache
}
func Default() *Client {
@@ -36,17 +34,13 @@
func New(url string) *Client {
return &Client{
- url: url,
- seen: make(map[string]bool),
- cacheComplete: false,
+ url: url,
+ cache: newCache(),
}
}
func (pc *Client) SetKnownModules(known []string) {
- for _, km := range known {
- pc.seen[km] = true
- }
- pc.cacheComplete = true
+ pc.cache.setKnownModules(known)
}
// Limit pkgsite requests to this many per second.
@@ -64,13 +58,11 @@
// Known reports whether pkgsite knows that modulePath actually refers
// to a module.
func (pc *Client) Known(ctx context.Context, modulePath string) (bool, error) {
- // If we've seen it before, no need to call.
- if b, ok := pc.seen[modulePath]; ok {
- return b, nil
+ found, ok := pc.cache.lookup(modulePath)
+ if ok {
+ return found, nil
}
- if pc.cacheComplete {
- return false, nil
- }
+
// Pause to maintain a max QPS.
if err := pkgsiteRateLimiter.Wait(ctx); err != nil {
return false, err
@@ -92,7 +84,7 @@
return false, err
}
known := res.StatusCode == http.StatusOK
- pc.seen[modulePath] = known
+ pc.cache.add(modulePath, known)
return known, nil
}
@@ -115,7 +107,10 @@
return seen, nil
}
-func (c *Client) writeKnown(w io.Writer) error {
+func (c *cache) writeKnown(w io.Writer) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
b, err := json.MarshalIndent(c.seen, "", " ")
if err != nil {
return err
@@ -164,7 +159,7 @@
if useRealPkgsite {
c := Default()
t.Cleanup(func() {
- err := c.writeKnown(rw)
+ err := c.cache.writeKnown(rw)
if err != nil {
t.Error(err)
}
@@ -184,3 +179,54 @@
t.Cleanup(s.Close)
return New(s.URL), nil
}
+
+type cache struct {
+ mu sync.Mutex // guards seen and complete
+ // seen maps each module path recorded so far to whether it is a known module.
+ seen map[string]bool
+ // complete is true once seen lists every known module; a path missing from seen is then not known.
+ complete bool
+}
+
+func newCache() *cache {
+ return &cache{
+ seen: make(map[string]bool), // non-nil so add and setKnownModules can write entries
+ complete: false, // until setKnownModules runs, a path missing from seen is inconclusive
+ }
+}
+
+func (c *cache) setKnownModules(known []string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ // Mark every listed path as known, then declare the cache exhaustive.
+ for _, km := range known {
+ c.seen[km] = true
+ }
+ c.complete = true // from here on, lookup reports any path missing from seen as not known
+}
+
+func (c *cache) lookup(modulePath string) (known bool, ok bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // An entry exists: return the recorded answer, with ok=true marking it as conclusive.
+ if known, ok := c.seen[modulePath]; ok {
+ return known, true
+ }
+
+ // No entry, but a complete cache lists every known module, so
+ // modulePath is conclusively not known.
+ if c.complete {
+ return false, true
+ }
+
+ // No entry and the cache is incomplete: ok=false tells the caller the cache has no answer.
+ return false, false
+}
+
+func (c *cache) add(modulePath string, known bool) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ // Record the answer for modulePath, replacing any earlier entry.
+ c.seen[modulePath] = known
+}
diff --git a/internal/pkgsite/client_test.go b/internal/pkgsite/client_test.go
index 7f3d952..153c522 100644
--- a/internal/pkgsite/client_test.go
+++ b/internal/pkgsite/client_test.go
@@ -43,3 +43,38 @@
})
}
}
+
+func TestKnownParallel(t *testing.T) {
+ ctx := context.Background()
+ cf, err := CacheFile(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pc, err := TestClient(t, *usePkgsite, cf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Each case pairs a module path with the answer Known is expected to give for it.
+ for _, test := range []struct {
+ name string
+ in string
+ want bool
+ }{
+ {name: "valid", in: "golang.org/x/mod", want: true},
+ {name: "invalid", in: "github.com/something/something", want: false},
+ } {
+ test := test // per-iteration copy for the parallel closure (needed before Go 1.22)
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel() // subtests share pc, so they may call Known concurrently
+
+ got, err := pc.Known(ctx, test.in)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got != test.want {
+ t.Errorf("%s: got %t, want %t", test.in, got, test.want)
+ }
+ })
+ }
+}
diff --git a/internal/pkgsite/testdata/pkgsite/TestKnownParallel.json b/internal/pkgsite/testdata/pkgsite/TestKnownParallel.json
new file mode 100644
index 0000000..45abe60
--- /dev/null
+++ b/internal/pkgsite/testdata/pkgsite/TestKnownParallel.json
@@ -0,0 +1,4 @@
+{
+ "github.com/something/something": false,
+ "golang.org/x/mod": true
+}
\ No newline at end of file