client: support fetching entries by ID

Change-Id: I5ae5fee45af4471aab26c48769509b57c63a0a2d
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/355969
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
diff --git a/client/client.go b/client/client.go
index 27b1f4f..298d8b4 100644
--- a/client/client.go
+++ b/client/client.go
@@ -46,13 +46,17 @@
 	"golang.org/x/vulndb/osv"
 )
 
-// Client interface for fetching vulnerabilities based on module path
+// Client interface for fetching vulnerabilities based on module path or ID.
 type Client interface {
+	// TODO(jba): rename to GetByModule
 	Get(string) ([]*osv.Entry, error)
+	GetByID(string) (*osv.Entry, error)
 }
 
+const idDir = "byID"
+
 type source interface {
-	Get(string) ([]*osv.Entry, error)
+	Client
 	Index() (osv.DBIndex, error)
 }
 
@@ -74,6 +78,20 @@
 	return e, nil
 }
 
+func (ls *localSource) GetByID(id string) (*osv.Entry, error) {
+	content, err := ioutil.ReadFile(filepath.Join(ls.dir, idDir, id+".json"))
+	if os.IsNotExist(err) {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
+	}
+	var e osv.Entry
+	if err = json.Unmarshal(content, &e); err != nil {
+		return nil, err
+	}
+	return &e, nil
+}
+
 func (ls *localSource) Index() (osv.DBIndex, error) {
 	var index osv.DBIndex
 	b, err := ioutil.ReadFile(filepath.Join(ls.dir, "index.json"))
@@ -183,17 +201,8 @@
 		}
 	}
 
-	resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, module))
-	if err != nil {
-		return nil, err
-	}
-	defer resp.Body.Close()
-	if resp.StatusCode == http.StatusNotFound {
-		return nil, nil
-	}
-	// might want this to be a LimitedReader
-	content, err := ioutil.ReadAll(resp.Body)
-	if err != nil {
+	content, err := hs.readBody(fmt.Sprintf("%s/%s.json", hs.url, module))
+	if err != nil || content == nil {
 		return nil, err
 	}
 	var e []*osv.Entry
@@ -211,6 +220,32 @@
 	return e, nil
 }
 
+func (hs *httpSource) GetByID(id string) (*osv.Entry, error) {
+	// TODO(jba): cache?
+	content, err := hs.readBody(fmt.Sprintf("%s/%s/%s.json", hs.url, idDir, id))
+	if err != nil || content == nil {
+		return nil, err
+	}
+	var e osv.Entry
+	if err := json.Unmarshal(content, &e); err != nil {
+		return nil, err
+	}
+	return &e, nil
+}
+
+func (hs *httpSource) readBody(url string) ([]byte, error) {
+	resp, err := hs.c.Get(url)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode == http.StatusNotFound {
+		return nil, nil
+	}
+	// might want this to be a LimitedReader
+	return ioutil.ReadAll(resp.Body)
+}
+
 type client struct {
 	sources []source
 }
@@ -263,3 +298,16 @@
 	}
 	return entries, nil
 }
+
+func (c *client) GetByID(id string) (*osv.Entry, error) {
+	for _, s := range c.sources {
+		entry, err := s.GetByID(id)
+		if err != nil {
+			return nil, err // be failure tolerant?
+		}
+		if entry != nil {
+			return entry, nil
+		}
+	}
+	return nil, nil
+}
diff --git a/client/client_test.go b/client/client_test.go
index 09adda1..4a00220 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -19,26 +19,29 @@
 	"testing"
 	"time"
 
+	"github.com/google/go-cmp/cmp"
 	"golang.org/x/vulndb/osv"
 )
 
-var testVuln string = `[
+var (
+	testVuln = `
 	{"ID":"ID","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
 	 "Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.2.0"}]},
 	 "ecosystem_specific":{"Symbols":["some_symbol_1"]
-	}}]`
+	}}`
+
+	testVulns = "[" + testVuln + "]"
+)
 
 // index containing timestamps for package in testVuln.
 var index string = `{
 	"golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00"
 	}`
 
-func serveTestVuln(w http.ResponseWriter, req *http.Request) {
-	fmt.Fprint(w, testVuln)
-}
-
-func serveIndex(w http.ResponseWriter, req *http.Request) {
-	fmt.Fprint(w, index)
+func dataHandler(data string) http.HandlerFunc {
+	return func(w http.ResponseWriter, _ *http.Request) {
+		fmt.Fprint(w, data)
+	}
 }
 
 // testCache for testing purposes
@@ -116,16 +119,19 @@
 	return ioutil.WriteFile(path.Join(dir, file), []byte(content), 0644)
 }
 
-// localDB creates a local db with testVuln1, testVuln2, and index as contents.
+// localDB creates a local db with testVulns and index as contents.
 func localDB(t *testing.T) (string, error) {
 	dbName := t.TempDir()
 
-	if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln); err != nil {
+	if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVulns); err != nil {
 		return "", err
 	}
 	if err := createDirAndFile(path.Join(dbName, ""), "index.json", index); err != nil {
 		return "", err
 	}
+	if err := createDirAndFile(path.Join(dbName, idDir), "ID.json", testVuln); err != nil {
+		return "", err
+	}
 	return dbName, nil
 }
 
@@ -135,8 +141,8 @@
 	}
 
 	// Create a local http database.
-	http.HandleFunc("/golang.org/example/one.json", serveTestVuln)
-	http.HandleFunc("/index.json", serveIndex)
+	http.HandleFunc("/golang.org/example/one.json", dataHandler(testVulns))
+	http.HandleFunc("/index.json", dataHandler(index))
 
 	l, err := net.Listen("tcp", "127.0.0.1:")
 	if err != nil {
@@ -255,3 +261,49 @@
 		t.Errorf("want %v vuln; got %v", e, vulns)
 	}
 }
+
+func TestClientByID(t *testing.T) {
+	if runtime.GOOS == "js" {
+		t.Skip("skipping test: no network on js")
+	}
+
+	http.HandleFunc("/byID/ID.json", dataHandler(testVuln))
+	l, err := net.Listen("tcp", "127.0.0.1:")
+	if err != nil {
+		t.Fatalf("failed to listen on 127.0.0.1: %s", err)
+	}
+	_, port, _ := net.SplitHostPort(l.Addr().String())
+	go func() { http.Serve(l, nil) }()
+
+	// Create a local file database.
+	localDBName, err := localDB(t)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(localDBName)
+
+	var want osv.Entry
+	if err := json.Unmarshal([]byte(testVuln), &want); err != nil {
+		t.Fatal(err)
+	}
+	for _, test := range []struct {
+		name   string
+		source string
+	}{
+		{name: "http", source: "http://localhost:" + port},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			client, err := NewClient([]string{test.source}, Options{})
+			if err != nil {
+				t.Fatal(err)
+			}
+			got, err := client.GetByID("ID")
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !cmp.Equal(got, &want) {
+				t.Errorf("got\n%+v\nwant\n%+v", got, &want)
+			}
+		})
+	}
+}