vulndb/client: remove cache implementation and expose db interface

Change-Id: Ie0a81b52b684c2e02e857c8105f229689593e172
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/347449
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/client/cache.go b/client/cache.go
index 20079cc..a9a540f 100644
--- a/client/cache.go
+++ b/client/cache.go
@@ -5,122 +5,15 @@
 package client
 
 import (
-	"encoding/json"
-	"go/build"
-	"io/ioutil"
-	"os"
-	"path/filepath"
 	"time"
 
 	"golang.org/x/vulndb/osv"
 )
 
-// NOTE: this cache implementation should be internal to the go tooling
-// (i.e. cmd/go/internal/something) so that the vulndb cache is owned
-// by the go command. Also it is currently NOT CONCURRENCY SAFE since
-// it does not implement file locking. When ported to the stdlib it
-// should use cmd/go/internal/lockedfile.
-
-// The cache uses a single JSON index file for each vulnerability database
-// which contains the map from packages to the time the last
-// vulnerability for that package was added/modified and the time that
-// the index was retrieved from the vulnerability database. The JSON
-// format is as follows:
-//
-// $GOPATH/pkg/mod/cache/download/vulndb/{db hostname}/indexes/index.json
-//   {
-//       Retrieved time.Time
-//       Index osv.DBIndex
-//   }
-//
-// Each package also has a JSON file which contains the array of vulnerability
-// entries for the package. The JSON format is as follows:
-//
-// $GOPATH/pkg/mod/cache/download/vulndb/{db hostname}/{import path}/vulns.json
-//   []*osv.Entry
-
+// Cache interface for vuln DB caching.
 type Cache interface {
 	ReadIndex(string) (osv.DBIndex, time.Time, error)
 	WriteIndex(string, osv.DBIndex, time.Time) error
 	ReadEntries(string, string) ([]*osv.Entry, error)
 	WriteEntries(string, string, []*osv.Entry) error
 }
-
-type fsCache struct{}
-
-// NewFsCache returns a fresh filesystem cache.
-// TODO: remove once the cache implementation reaches the go tooling repo.
-func NewFsCache() Cache {
-	return &fsCache{}
-}
-
-// should be cfg.GOMODCACHE when doing this inside the cmd/go/internal
-var cacheRoot = filepath.Join(build.Default.GOPATH, "/pkg/mod/cache/download/vulndb")
-
-type cachedIndex struct {
-	Retrieved time.Time
-	Index     osv.DBIndex
-}
-
-func (c *fsCache) ReadIndex(dbName string) (osv.DBIndex, time.Time, error) {
-	b, err := ioutil.ReadFile(filepath.Join(cacheRoot, dbName, "index.json"))
-	if err != nil {
-		if os.IsNotExist(err) {
-			return nil, time.Time{}, nil
-		}
-		return nil, time.Time{}, err
-	}
-	var index cachedIndex
-	if err := json.Unmarshal(b, &index); err != nil {
-		return nil, time.Time{}, err
-	}
-	return index.Index, index.Retrieved, nil
-}
-
-func (c *fsCache) WriteIndex(dbName string, index osv.DBIndex, retrieved time.Time) error {
-	path := filepath.Join(cacheRoot, dbName)
-	if err := os.MkdirAll(path, 0777); err != nil {
-		return err
-	}
-	j, err := json.Marshal(cachedIndex{
-		Index:     index,
-		Retrieved: retrieved,
-	})
-	if err != nil {
-		return err
-	}
-	if err := ioutil.WriteFile(filepath.Join(path, "index.json"), j, 0666); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (c *fsCache) ReadEntries(dbName string, p string) ([]*osv.Entry, error) {
-	b, err := ioutil.ReadFile(filepath.Join(cacheRoot, dbName, p, "vulns.json"))
-	if err != nil {
-		if os.IsNotExist(err) {
-			return nil, nil
-		}
-		return nil, err
-	}
-	var entries []*osv.Entry
-	if err := json.Unmarshal(b, &entries); err != nil {
-		return nil, err
-	}
-	return entries, nil
-}
-
-func (c *fsCache) WriteEntries(dbName string, p string, entries []*osv.Entry) error {
-	path := filepath.Join(cacheRoot, dbName, p)
-	if err := os.MkdirAll(path, 0777); err != nil {
-		return err
-	}
-	j, err := json.Marshal(entries)
-	if err != nil {
-		return err
-	}
-	if err := ioutil.WriteFile(filepath.Join(path, "vulns.json"), j, 0666); err != nil {
-		return err
-	}
-	return nil
-}
diff --git a/client/cache_test.go b/client/cache_test.go
deleted file mode 100644
index 6510072..0000000
--- a/client/cache_test.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2021 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package client
-
-import (
-	"os"
-	"path/filepath"
-	"reflect"
-	"testing"
-	"time"
-
-	"golang.org/x/vulndb/osv"
-)
-
-func TestCache(t *testing.T) {
-	originalRoot := cacheRoot
-	defer func() { cacheRoot = originalRoot }()
-
-	tmpDir := t.TempDir()
-	cacheRoot = tmpDir
-
-	cache := &fsCache{}
-	dbName := "vulndb.golang.org"
-
-	_, _, err := cache.ReadIndex(dbName)
-	if err != nil {
-		t.Fatalf("ReadIndex failed for non-existent database: %v", err)
-	}
-
-	if err = os.Mkdir(filepath.Join(tmpDir, dbName), 0777); err != nil {
-		t.Fatalf("os.Mkdir failed: %v", err)
-	}
-	_, _, err = cache.ReadIndex(dbName)
-	if err != nil {
-		t.Fatalf("ReadIndex failed for database without cached index: %v", err)
-	}
-
-	now := time.Now()
-	expectedIdx := osv.DBIndex{
-		"a.vuln.example.com": time.Time{}.Add(time.Hour),
-		"b.vuln.example.com": time.Time{}.Add(time.Hour * 2),
-		"c.vuln.example.com": time.Time{}.Add(time.Hour * 3),
-	}
-	if err = cache.WriteIndex(dbName, expectedIdx, now); err != nil {
-		t.Fatalf("WriteIndex failed to write index: %v", err)
-	}
-
-	idx, retrieved, err := cache.ReadIndex(dbName)
-	if err != nil {
-		t.Fatalf("ReadIndex failed for database with cached index: %v", err)
-	}
-	if !reflect.DeepEqual(idx, expectedIdx) {
-		t.Errorf("ReadIndex returned unexpected index, got:\n%s\nwant:\n%s", idx, expectedIdx)
-	}
-	if !retrieved.Equal(now) {
-		t.Errorf("ReadIndex returned unexpected retrieved: got %s, want %s", retrieved, now)
-	}
-
-	if _, err = cache.ReadEntries(dbName, "vuln.example.com"); err != nil {
-		t.Fatalf("ReadEntires failed for non-existent package: %v", err)
-	}
-
-	expectedEntries := []*osv.Entry{
-		&osv.Entry{ID: "001"},
-		&osv.Entry{ID: "002"},
-		&osv.Entry{ID: "003"},
-	}
-	if err := cache.WriteEntries(dbName, "vuln.example.com", expectedEntries); err != nil {
-		t.Fatalf("WriteEntries failed: %v", err)
-	}
-
-	entries, err := cache.ReadEntries(dbName, "vuln.example.com")
-	if err != nil {
-		t.Fatalf("ReadEntries failed for cached package: %v", err)
-	}
-	if !reflect.DeepEqual(entries, expectedEntries) {
-		t.Errorf("ReadEntries returned unexpected entries, got:\n%v\nwant:\n%v", entries, expectedEntries)
-	}
-}
diff --git a/client/client.go b/client/client.go
index b789ccc..6aa10ad 100644
--- a/client/client.go
+++ b/client/client.go
@@ -46,6 +46,11 @@
 	"golang.org/x/vulndb/osv"
 )
 
+// Client interface for fetching vulnerabilities based on module path
+type Client interface {
+	Get(string) ([]*osv.Entry, error)
+}
+
 type source interface {
 	Get(string) ([]*osv.Entry, error)
 	Index() (osv.DBIndex, error)
@@ -199,7 +204,7 @@
 	return e, nil
 }
 
-type Client struct {
+type client struct {
 	sources []source
 }
 
@@ -208,8 +213,8 @@
 	HTTPCache  Cache
 }
 
-func NewClient(sources []string, opts Options) (*Client, error) {
-	c := &Client{}
+func NewClient(sources []string, opts Options) (Client, error) {
+	c := &client{}
 	for _, uri := range sources {
 		uri = strings.TrimRight(uri, "/")
 		// should parse the URI out here instead of in there
@@ -239,7 +244,7 @@
 	return c, nil
 }
 
-func (c *Client) Get(module string) ([]*osv.Entry, error) {
+func (c *client) Get(module string) ([]*osv.Entry, error) {
 	var entries []*osv.Entry
 	// probably should be parallelized
 	for _, s := range c.sources {
diff --git a/client/client_test.go b/client/client_test.go
index af371c0..cd36fa8 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -40,21 +40,70 @@
 	fmt.Fprintf(w, index)
 }
 
+// testCache for testing purposes
+type testCache struct {
+	indexMap   map[string]osv.DBIndex
+	indexStamp map[string]time.Time
+	vulnMap    map[string]map[string][]*osv.Entry
+}
+
+func freshTestCache() *testCache {
+	return &testCache{
+		indexMap:   make(map[string]osv.DBIndex),
+		indexStamp: make(map[string]time.Time),
+		vulnMap:    make(map[string]map[string][]*osv.Entry),
+	}
+}
+
+func (tc *testCache) ReadIndex(db string) (osv.DBIndex, time.Time, error) {
+	index, ok := tc.indexMap[db]
+	if !ok {
+		return nil, time.Time{}, nil
+	}
+	stamp, ok := tc.indexStamp[db]
+	if !ok {
+		return nil, time.Time{}, nil
+	}
+	return index, stamp, nil
+}
+
+func (tc *testCache) WriteIndex(db string, index osv.DBIndex, stamp time.Time) error {
+	tc.indexMap[db] = index
+	tc.indexStamp[db] = stamp
+	return nil
+}
+
+func (tc *testCache) ReadEntries(db, module string) ([]*osv.Entry, error) {
+	mMap, ok := tc.vulnMap[db]
+	if !ok {
+		return nil, nil
+	}
+	return mMap[module], nil
+}
+
+func (tc *testCache) WriteEntries(db, module string, entries []*osv.Entry) error {
+	mMap, ok := tc.vulnMap[db]
+	if !ok {
+		mMap = make(map[string][]*osv.Entry)
+		tc.vulnMap[db] = mMap
+	}
+	mMap[module] = append(mMap[module], entries...)
+	return nil
+}
+
 // cachedTestVuln returns a function creating a local cache
 // for db with `dbName` with a version of testVuln where
 // Summary="cached" and LastModified happened after entry
 // in the `index` for the same pkg.
-func cachedTestVuln(dbName string) func() Cache {
-	return func() Cache {
-		c := &fsCache{}
-		e := &osv.Entry{
-			ID:       "ID1",
-			Details:  "cached",
-			Modified: time.Now(),
-		}
-		c.WriteEntries(dbName, "golang.org/example/one", []*osv.Entry{e})
-		return c
+func cachedTestVuln(dbName string) Cache {
+	c := freshTestCache()
+	e := &osv.Entry{
+		ID:       "ID1",
+		Details:  "cached",
+		Modified: time.Now(),
 	}
+	c.WriteEntries(dbName, "golang.org/example/one", []*osv.Entry{e})
+	return c
 }
 
 // createDirAndFile creates a directory `dir` if such directory does
@@ -103,28 +152,25 @@
 	defer os.RemoveAll(localDBName)
 
 	for _, test := range []struct {
-		name        string
-		source      string
-		createCache func() Cache
+		name   string
+		source string
+		cache  Cache
 		// cache summary for testVuln
 		summary string
 	}{
 		// Test the http client without any cache.
-		{name: "http-no-cache", source: "http://localhost:" + port, createCache: func() Cache { return nil }, summary: ""},
+		{name: "http-no-cache", source: "http://localhost:" + port, cache: nil, summary: ""},
 		// Test the http client with empty cache.
-		{name: "http-empty-cache", source: "http://localhost:" + port, createCache: func() Cache { return &fsCache{} }, summary: ""},
+		{name: "http-empty-cache", source: "http://localhost:" + port, cache: freshTestCache(), summary: ""},
 		// Test the client with non-stale cache containing a version of testVuln2 where Summary="cached".
-		{name: "http-cache", source: "http://localhost:" + port, createCache: cachedTestVuln("localhost"), summary: "cached"},
+		{name: "http-cache", source: "http://localhost:" + port, cache: cachedTestVuln("localhost"), summary: "cached"},
 		// Repeat the same for local file client.
-		{name: "file-no-cache", source: "file://" + localDBName, createCache: func() Cache { return nil }, summary: ""},
-		{name: "file-empty-cache", source: "file://" + localDBName, createCache: func() Cache { return &fsCache{} }, summary: ""},
+		{name: "file-no-cache", source: "file://" + localDBName, cache: nil, summary: ""},
+		{name: "file-empty-cache", source: "file://" + localDBName, cache: freshTestCache(), summary: ""},
 		// Cache does not play a role in local file databases.
-		{name: "file-cache", source: "file://" + localDBName, createCache: cachedTestVuln(localDBName), summary: ""},
+		{name: "file-cache", source: "file://" + localDBName, cache: cachedTestVuln(localDBName), summary: ""},
 	} {
-		// Create fresh cache location each time.
-		cacheRoot = t.TempDir()
-
-		client, err := NewClient([]string{test.source}, Options{HTTPCache: test.createCache()})
+		client, err := NewClient([]string{test.source}, Options{HTTPCache: test.cache})
 		if err != nil {
 			t.Fatal(err)
 		}