vulndb/client: make Get accept a single module path
Current signature of Get accepts a list of module paths. A single
argument is a cleaner solution not affecting client library usability.
This CL makes the switch and cleans up unit testing.
Change-Id: Ic67fa02e0372f19882b75c47ced8f1a2a9b3a233
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/344870
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: kokoro <noreply+kokoro@google.com>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/client/client.go b/client/client.go
index c7e811e..b789ccc 100644
--- a/client/client.go
+++ b/client/client.go
@@ -47,7 +47,7 @@
)
type source interface {
- Get([]string) ([]*osv.Entry, error)
+ Get(string) ([]*osv.Entry, error)
Index() (osv.DBIndex, error)
}
@@ -55,22 +55,18 @@
dir string
}
-func (ls *localSource) Get(modules []string) ([]*osv.Entry, error) {
- var entries []*osv.Entry
- for _, p := range modules {
- content, err := ioutil.ReadFile(filepath.Join(ls.dir, p+".json"))
- if os.IsNotExist(err) {
- continue
- } else if err != nil {
- return nil, err
- }
- var e []*osv.Entry
- if err = json.Unmarshal(content, &e); err != nil {
- return nil, err
- }
- entries = append(entries, e...)
+func (ls *localSource) Get(module string) ([]*osv.Entry, error) {
+ content, err := ioutil.ReadFile(filepath.Join(ls.dir, module+".json"))
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
}
- return entries, nil
+ var e []*osv.Entry
+ if err = json.Unmarshal(content, &e); err != nil {
+ return nil, err
+ }
+ return e, nil
}
func (ls *localSource) Index() (osv.DBIndex, error) {
@@ -147,69 +143,60 @@
return index, nil
}
-func (hs *httpSource) Get(modules []string) ([]*osv.Entry, error) {
- var entries []*osv.Entry
-
+func (hs *httpSource) Get(module string) ([]*osv.Entry, error) {
index, err := hs.Index()
if err != nil {
return nil, err
}
- var stillNeed []string
- for _, p := range modules {
- lastModified, present := index[p]
- if !present {
- continue
- }
- if hs.cache != nil {
- if cached, err := hs.cache.ReadEntries(hs.dbName, p); err != nil {
- return nil, err
- } else if len(cached) != 0 {
- var stale bool
- for _, c := range cached {
- if c.Modified.Before(lastModified) {
- stale = true
- break
- }
- }
- if !stale {
- entries = append(entries, cached...)
- continue
- }
- }
- }
- stillNeed = append(stillNeed, p)
+ lastModified, present := index[module]
+ if !present {
+ return nil, nil
}
- for _, p := range stillNeed {
- resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, p))
- if err != nil {
+ if hs.cache != nil {
+ if cached, err := hs.cache.ReadEntries(hs.dbName, module); err != nil {
return nil, err
- }
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusNotFound {
- continue
- }
- // might want this to be a LimitedReader
- content, err := ioutil.ReadAll(resp.Body)
- if err != nil {
- return nil, err
- }
- var e []*osv.Entry
- if err = json.Unmarshal(content, &e); err != nil {
- return nil, err
- }
- // TODO: we may want to check that the returned entries actually match
- // the module we asked about, so that the cache cannot be poisoned
- entries = append(entries, e...)
-
- if hs.cache != nil {
- if err := hs.cache.WriteEntries(hs.dbName, p, e); err != nil {
- return nil, err
+ } else if len(cached) != 0 {
+ var stale bool
+ for _, c := range cached {
+ if c.Modified.Before(lastModified) {
+ stale = true
+ break
+ }
+ }
+ if !stale {
+ return cached, nil
}
}
}
- return entries, nil
+
+ resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, module))
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return nil, nil
+ }
+ // might want this to be a LimitedReader
+ content, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+ var e []*osv.Entry
+ // TODO: we may want to check that the returned entries actually match
+ // the module we asked about, so that the cache cannot be poisoned
+ if err = json.Unmarshal(content, &e); err != nil {
+ return nil, err
+ }
+
+ if hs.cache != nil {
+ if err := hs.cache.WriteEntries(hs.dbName, module, e); err != nil {
+ return nil, err
+ }
+ }
+ return e, nil
}
type Client struct {
@@ -252,11 +239,11 @@
return c, nil
}
-func (c *Client) Get(modules []string) ([]*osv.Entry, error) {
+func (c *Client) Get(module string) ([]*osv.Entry, error) {
var entries []*osv.Entry
// probably should be parallelized
for _, s := range c.sources {
- e, err := s.Get(modules)
+ e, err := s.Get(module)
if err != nil {
return nil, err // be failure tolerant?
}
diff --git a/client/client_test.go b/client/client_test.go
index c8c04cd..af371c0 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -21,49 +21,38 @@
"golang.org/x/vulndb/osv"
)
-var testVuln1 string = `[
- {"ID":"ID1","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
+var testVuln string = `[
+ {"ID":"ID","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
"Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.2.0"}]},
"ecosystem_specific":{"Symbols":["some_symbol_1"]
}}]`
-var testVuln2 string = `[
- {"ID":"ID2","Package":{"Name":"golang.org/example/two","Ecosystem":"go"}, "Summary":"",
- "Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.1.0"}]},
- "ecosystem_specific":{"Symbols":["some_symbol_2"]
- }}]`
-
-// index containing timestamps for packages in testVuln1 and testVuln2.
+// index containing timestamps for package in testVuln.
var index string = `{
- "golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00",
- "golang.org/example/two": "2019-02-05T09:00:00.31561157-07:00"
+ "golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00"
}`
-func serveTestVuln1(w http.ResponseWriter, req *http.Request) {
- fmt.Fprintf(w, testVuln1)
-}
-
-func serveTestVuln2(w http.ResponseWriter, req *http.Request) {
- fmt.Fprintf(w, testVuln2)
+func serveTestVuln(w http.ResponseWriter, req *http.Request) {
+ fmt.Fprintf(w, testVuln)
}
func serveIndex(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, index)
}
-// cachedTestVuln2 returns a function creating a local cache
-// for db with `dbName` with a version of testVuln2 where
+// cachedTestVuln returns a function creating a local cache
+// for db with `dbName` with a version of testVuln where
// Summary="cached" and LastModified happened after entry
// in the `index` for the same pkg.
-func cachedTestVuln2(dbName string) func() Cache {
+func cachedTestVuln(dbName string) func() Cache {
return func() Cache {
c := &fsCache{}
e := &osv.Entry{
- ID: "ID2",
+ ID: "ID1",
Details: "cached",
Modified: time.Now(),
}
- c.WriteEntries(dbName, "golang.org/example/two", []*osv.Entry{e})
+ c.WriteEntries(dbName, "golang.org/example/one", []*osv.Entry{e})
return c
}
}
@@ -81,10 +70,7 @@
func localDB(t *testing.T) (string, error) {
dbName := t.TempDir()
- if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln1); err != nil {
- return "", err
- }
- if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "two.json", testVuln2); err != nil {
+ if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln); err != nil {
return "", err
}
if err := createDirAndFile(path.Join(dbName, ""), "index.json", index); err != nil {
@@ -99,8 +85,7 @@
}
// Create a local http database.
- http.HandleFunc("/golang.org/example/one.json", serveTestVuln1)
- http.HandleFunc("/golang.org/example/two.json", serveTestVuln2)
+ http.HandleFunc("/golang.org/example/one.json", serveTestVuln)
http.HandleFunc("/index.json", serveIndex)
l, err := net.Listen("tcp", "127.0.0.1:")
@@ -121,20 +106,20 @@
name string
source string
createCache func() Cache
- noVulns int
- summaries map[string]string
+ // cache summary for testVuln
+ summary string
}{
// Test the http client without any cache.
- {name: "http-no-cache", source: "http://localhost:" + port, createCache: func() Cache { return nil }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+ {name: "http-no-cache", source: "http://localhost:" + port, createCache: func() Cache { return nil }, summary: ""},
// Test the http client with empty cache.
- {name: "http-empty-cache", source: "http://localhost:" + port, createCache: func() Cache { return &fsCache{} }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+ {name: "http-empty-cache", source: "http://localhost:" + port, createCache: func() Cache { return &fsCache{} }, summary: ""},
// Test the client with non-stale cache containing a version of testVuln2 where Summary="cached".
- {name: "http-cache", source: "http://localhost:" + port, createCache: cachedTestVuln2("localhost"), noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": "cached"}},
+ {name: "http-cache", source: "http://localhost:" + port, createCache: cachedTestVuln("localhost"), summary: "cached"},
// Repeat the same for local file client.
- {name: "file-no-cache", source: "file://" + localDBName, createCache: func() Cache { return nil }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
- {name: "file-empty-cache", source: "file://" + localDBName, createCache: func() Cache { return &fsCache{} }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+ {name: "file-no-cache", source: "file://" + localDBName, createCache: func() Cache { return nil }, summary: ""},
+ {name: "file-empty-cache", source: "file://" + localDBName, createCache: func() Cache { return &fsCache{} }, summary: ""},
// Cache does not play a role in local file databases.
- {name: "file-cache", source: "file://" + localDBName, createCache: cachedTestVuln2(localDBName), noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+ {name: "file-cache", source: "file://" + localDBName, createCache: cachedTestVuln(localDBName), summary: ""},
} {
// Create fresh cache location each time.
cacheRoot = t.TempDir()
@@ -144,18 +129,17 @@
t.Fatal(err)
}
- vulns, err := client.Get([]string{"golang.org/example/one", "golang.org/example/two"})
+ vulns, err := client.Get("golang.org/example/one")
if err != nil {
t.Fatal(err)
}
- if len(vulns) != test.noVulns {
- t.Errorf("want %v vulns for %s; got %v", test.noVulns, test.name, len(vulns))
+
+ if len(vulns) != 1 {
+ t.Errorf("%s: want 1 vuln for golang.org/example/one; got %v", test.name, len(vulns))
}
- for _, v := range vulns {
- if s, ok := test.summaries[v.ID]; !ok || v.Details != s {
- t.Errorf("want '%s' summary for vuln with id %v in %s; got '%s'", s, v.ID, test.name, v.Details)
- }
+ if v := vulns[0]; v.Details != test.summary {
+ t.Errorf("%s: want '%s' summary for testVuln; got '%s'", test.name, test.summary, v.Details)
}
}
}
@@ -177,11 +161,12 @@
defer ts.Close()
hs := &httpSource{url: ts.URL, c: new(http.Client)}
- _, err := hs.Get([]string{"a", "b", "c"})
- if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ for _, module := range []string{"a", "b", "c"} {
+ if _, err := hs.Get(module); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
}
- expectedFetches := map[string]int{"/index.json": 1, "/a.json": 1, "/b.json": 1}
+ expectedFetches := map[string]int{"/index.json": 3, "/a.json": 1, "/b.json": 1}
if !reflect.DeepEqual(fetches, expectedFetches) {
t.Errorf("unexpected fetches, got %v, want %v", fetches, expectedFetches)
}