client: support fetching entries by ID
Change-Id: I5ae5fee45af4471aab26c48769509b57c63a0a2d
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/355969
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
diff --git a/client/client.go b/client/client.go
index 27b1f4f..298d8b4 100644
--- a/client/client.go
+++ b/client/client.go
@@ -46,13 +46,17 @@
"golang.org/x/vulndb/osv"
)
-// Client interface for fetching vulnerabilities based on module path
+// Client interface for fetching vulnerabilities based on module path or ID.
type Client interface {
+ // TODO(jba): rename to GetByModule
Get(string) ([]*osv.Entry, error)
+ GetByID(string) (*osv.Entry, error)
}
+const idDir = "byID"
+
type source interface {
- Get(string) ([]*osv.Entry, error)
+ Client
Index() (osv.DBIndex, error)
}
@@ -74,6 +78,20 @@
return e, nil
}
+func (ls *localSource) GetByID(id string) (*osv.Entry, error) {
+ content, err := ioutil.ReadFile(filepath.Join(ls.dir, idDir, id+".json"))
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ var e osv.Entry
+ if err = json.Unmarshal(content, &e); err != nil {
+ return nil, err
+ }
+ return &e, nil
+}
+
func (ls *localSource) Index() (osv.DBIndex, error) {
var index osv.DBIndex
b, err := ioutil.ReadFile(filepath.Join(ls.dir, "index.json"))
@@ -183,17 +201,8 @@
}
}
- resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, module))
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusNotFound {
- return nil, nil
- }
- // might want this to be a LimitedReader
- content, err := ioutil.ReadAll(resp.Body)
- if err != nil {
+ content, err := hs.readBody(fmt.Sprintf("%s/%s.json", hs.url, module))
+ if err != nil || content == nil {
return nil, err
}
var e []*osv.Entry
@@ -211,6 +220,32 @@
return e, nil
}
+func (hs *httpSource) GetByID(id string) (*osv.Entry, error) {
+ // TODO(jba): cache?
+ content, err := hs.readBody(fmt.Sprintf("%s/%s/%s.json", hs.url, idDir, id))
+ if err != nil || content == nil {
+ return nil, err
+ }
+ var e osv.Entry
+ if err := json.Unmarshal(content, &e); err != nil {
+ return nil, err
+ }
+ return &e, nil
+}
+
+func (hs *httpSource) readBody(url string) ([]byte, error) {
+ resp, err := hs.c.Get(url)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return nil, nil
+ }
+ // might want this to be a LimitedReader
+ return ioutil.ReadAll(resp.Body)
+}
+
type client struct {
sources []source
}
@@ -263,3 +298,16 @@
}
return entries, nil
}
+
+func (c *client) GetByID(id string) (*osv.Entry, error) {
+ for _, s := range c.sources {
+ entry, err := s.GetByID(id)
+ if err != nil {
+ return nil, err // be failure tolerant?
+ }
+ if entry != nil {
+ return entry, nil
+ }
+ }
+ return nil, nil
+}
diff --git a/client/client_test.go b/client/client_test.go
index 09adda1..4a00220 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -19,26 +19,29 @@
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"golang.org/x/vulndb/osv"
)
-var testVuln string = `[
+var (
+ testVuln = `
{"ID":"ID","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
"Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.2.0"}]},
"ecosystem_specific":{"Symbols":["some_symbol_1"]
- }}]`
+ }}`
+
+ testVulns = "[" + testVuln + "]"
+)
// index containing timestamps for package in testVuln.
var index string = `{
"golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00"
}`
-func serveTestVuln(w http.ResponseWriter, req *http.Request) {
- fmt.Fprint(w, testVuln)
-}
-
-func serveIndex(w http.ResponseWriter, req *http.Request) {
- fmt.Fprint(w, index)
+func dataHandler(data string) http.HandlerFunc {
+ return func(w http.ResponseWriter, _ *http.Request) {
+ fmt.Fprint(w, data)
+ }
}
// testCache for testing purposes
@@ -116,16 +119,19 @@
return ioutil.WriteFile(path.Join(dir, file), []byte(content), 0644)
}
-// localDB creates a local db with testVuln1, testVuln2, and index as contents.
+// localDB creates a local db with testVulns and index as contents.
func localDB(t *testing.T) (string, error) {
dbName := t.TempDir()
- if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln); err != nil {
+ if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVulns); err != nil {
return "", err
}
if err := createDirAndFile(path.Join(dbName, ""), "index.json", index); err != nil {
return "", err
}
+ if err := createDirAndFile(path.Join(dbName, idDir), "ID.json", testVuln); err != nil {
+ return "", err
+ }
return dbName, nil
}
@@ -135,8 +141,8 @@
}
// Create a local http database.
- http.HandleFunc("/golang.org/example/one.json", serveTestVuln)
- http.HandleFunc("/index.json", serveIndex)
+ http.HandleFunc("/golang.org/example/one.json", dataHandler(testVulns))
+ http.HandleFunc("/index.json", dataHandler(index))
l, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
@@ -255,3 +261,49 @@
t.Errorf("want %v vuln; got %v", e, vulns)
}
}
+
+func TestClientByID(t *testing.T) {
+ if runtime.GOOS == "js" {
+ t.Skip("skipping test: no network on js")
+ }
+
+ http.HandleFunc("/byID/ID.json", dataHandler(testVuln))
+ l, err := net.Listen("tcp", "127.0.0.1:")
+ if err != nil {
+ t.Fatalf("failed to listen on 127.0.0.1: %s", err)
+ }
+ _, port, _ := net.SplitHostPort(l.Addr().String())
+ go func() { http.Serve(l, nil) }()
+
+ // Create a local file database.
+ localDBName, err := localDB(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(localDBName)
+
+ var want osv.Entry
+ if err := json.Unmarshal([]byte(testVuln), &want); err != nil {
+ t.Fatal(err)
+ }
+ for _, test := range []struct {
+ name string
+ source string
+ }{
+ {name: "http", source: "http://localhost:" + port},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ client, err := NewClient([]string{test.source}, Options{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := client.GetByID("ID")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !cmp.Equal(got, &want) {
+ t.Errorf("got\n%+v\nwant\n%+v", got, &want)
+ }
+ })
+ }
+}