vulndb/client: make Get accept a single module path

Current signature of Get accepts a list of module paths. A single
argument is a cleaner solution not affecting client library usability.
This CL makes the switch and cleans up unit testing.

Change-Id: Ic67fa02e0372f19882b75c47ced8f1a2a9b3a233
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/344870
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: kokoro <noreply+kokoro@google.com>
Trust: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/client/client.go b/client/client.go
index c7e811e..b789ccc 100644
--- a/client/client.go
+++ b/client/client.go
@@ -47,7 +47,7 @@
 )
 
 type source interface {
-	Get([]string) ([]*osv.Entry, error)
+	Get(string) ([]*osv.Entry, error)
 	Index() (osv.DBIndex, error)
 }
 
@@ -55,22 +55,18 @@
 	dir string
 }
 
-func (ls *localSource) Get(modules []string) ([]*osv.Entry, error) {
-	var entries []*osv.Entry
-	for _, p := range modules {
-		content, err := ioutil.ReadFile(filepath.Join(ls.dir, p+".json"))
-		if os.IsNotExist(err) {
-			continue
-		} else if err != nil {
-			return nil, err
-		}
-		var e []*osv.Entry
-		if err = json.Unmarshal(content, &e); err != nil {
-			return nil, err
-		}
-		entries = append(entries, e...)
+func (ls *localSource) Get(module string) ([]*osv.Entry, error) {
+	content, err := ioutil.ReadFile(filepath.Join(ls.dir, module+".json"))
+	if os.IsNotExist(err) {
+		return nil, nil
+	} else if err != nil {
+		return nil, err
 	}
-	return entries, nil
+	var e []*osv.Entry
+	if err = json.Unmarshal(content, &e); err != nil {
+		return nil, err
+	}
+	return e, nil
 }
 
 func (ls *localSource) Index() (osv.DBIndex, error) {
@@ -147,69 +143,60 @@
 	return index, nil
 }
 
-func (hs *httpSource) Get(modules []string) ([]*osv.Entry, error) {
-	var entries []*osv.Entry
-
+func (hs *httpSource) Get(module string) ([]*osv.Entry, error) {
 	index, err := hs.Index()
 	if err != nil {
 		return nil, err
 	}
 
-	var stillNeed []string
-	for _, p := range modules {
-		lastModified, present := index[p]
-		if !present {
-			continue
-		}
-		if hs.cache != nil {
-			if cached, err := hs.cache.ReadEntries(hs.dbName, p); err != nil {
-				return nil, err
-			} else if len(cached) != 0 {
-				var stale bool
-				for _, c := range cached {
-					if c.Modified.Before(lastModified) {
-						stale = true
-						break
-					}
-				}
-				if !stale {
-					entries = append(entries, cached...)
-					continue
-				}
-			}
-		}
-		stillNeed = append(stillNeed, p)
+	lastModified, present := index[module]
+	if !present {
+		return nil, nil
 	}
 
-	for _, p := range stillNeed {
-		resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, p))
-		if err != nil {
+	if hs.cache != nil {
+		if cached, err := hs.cache.ReadEntries(hs.dbName, module); err != nil {
 			return nil, err
-		}
-		defer resp.Body.Close()
-		if resp.StatusCode == http.StatusNotFound {
-			continue
-		}
-		// might want this to be a LimitedReader
-		content, err := ioutil.ReadAll(resp.Body)
-		if err != nil {
-			return nil, err
-		}
-		var e []*osv.Entry
-		if err = json.Unmarshal(content, &e); err != nil {
-			return nil, err
-		}
-		// TODO: we may want to check that the returned entries actually match
-		// the module we asked about, so that the cache cannot be poisoned
-		entries = append(entries, e...)
-
-		if hs.cache != nil {
-			if err := hs.cache.WriteEntries(hs.dbName, p, e); err != nil {
-				return nil, err
+		} else if len(cached) != 0 {
+			var stale bool
+			for _, c := range cached {
+				if c.Modified.Before(lastModified) {
+					stale = true
+					break
+				}
+			}
+			if !stale {
+				return cached, nil
 			}
 		}
 	}
-	return entries, nil
+
+	resp, err := hs.c.Get(fmt.Sprintf("%s/%s.json", hs.url, module))
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode == http.StatusNotFound {
+		return nil, nil
+	}
+	// might want this to be a LimitedReader
+	content, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	var e []*osv.Entry
+	// TODO: we may want to check that the returned entries actually match
+	// the module we asked about, so that the cache cannot be poisoned
+	if err = json.Unmarshal(content, &e); err != nil {
+		return nil, err
+	}
+
+	if hs.cache != nil {
+		if err := hs.cache.WriteEntries(hs.dbName, module, e); err != nil {
+			return nil, err
+		}
+	}
+	return e, nil
 }
 
 type Client struct {
@@ -252,11 +239,11 @@
 	return c, nil
 }
 
-func (c *Client) Get(modules []string) ([]*osv.Entry, error) {
+func (c *Client) Get(module string) ([]*osv.Entry, error) {
 	var entries []*osv.Entry
 	// probably should be parallelized
 	for _, s := range c.sources {
-		e, err := s.Get(modules)
+		e, err := s.Get(module)
 		if err != nil {
 			return nil, err // be failure tolerant?
 		}
diff --git a/client/client_test.go b/client/client_test.go
index c8c04cd..af371c0 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -21,49 +21,38 @@
 	"golang.org/x/vulndb/osv"
 )
 
-var testVuln1 string = `[
-	{"ID":"ID1","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
+var testVuln string = `[
+	{"ID":"ID","Package":{"Name":"golang.org/example/one","Ecosystem":"go"}, "Summary":"",
 	 "Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.2.0"}]},
 	 "ecosystem_specific":{"Symbols":["some_symbol_1"]
 	}}]`
 
-var testVuln2 string = `[
-	{"ID":"ID2","Package":{"Name":"golang.org/example/two","Ecosystem":"go"}, "Summary":"",
-	 "Severity":2,"Affects":{"Ranges":[{"Type":"SEMVER","Introduced":"","Fixed":"v2.1.0"}]},
-	 "ecosystem_specific":{"Symbols":["some_symbol_2"]
-	}}]`
-
-// index containing timestamps for packages in testVuln1 and testVuln2.
+// index containing timestamps for package in testVuln.
 var index string = `{
-	"golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00",
-	"golang.org/example/two": "2019-02-05T09:00:00.31561157-07:00"
+	"golang.org/example/one": "2020-03-09T10:00:00.81362141-07:00"
 	}`
 
-func serveTestVuln1(w http.ResponseWriter, req *http.Request) {
-	fmt.Fprintf(w, testVuln1)
-}
-
-func serveTestVuln2(w http.ResponseWriter, req *http.Request) {
-	fmt.Fprintf(w, testVuln2)
+func serveTestVuln(w http.ResponseWriter, req *http.Request) {
+	fmt.Fprintf(w, testVuln)
 }
 
 func serveIndex(w http.ResponseWriter, req *http.Request) {
 	fmt.Fprintf(w, index)
 }
 
-// cachedTestVuln2 returns a function creating a local cache
-// for db with `dbName` with a version of testVuln2 where
+// cachedTestVuln returns a function creating a local cache
+// for db with `dbName` with a version of testVuln where
 // Summary="cached" and LastModified happened after entry
 // in the `index` for the same pkg.
-func cachedTestVuln2(dbName string) func() Cache {
+func cachedTestVuln(dbName string) func() Cache {
 	return func() Cache {
 		c := &fsCache{}
 		e := &osv.Entry{
-			ID:       "ID2",
+			ID:       "ID1",
 			Details:  "cached",
 			Modified: time.Now(),
 		}
-		c.WriteEntries(dbName, "golang.org/example/two", []*osv.Entry{e})
+		c.WriteEntries(dbName, "golang.org/example/one", []*osv.Entry{e})
 		return c
 	}
 }
@@ -81,10 +70,7 @@
 func localDB(t *testing.T) (string, error) {
 	dbName := t.TempDir()
 
-	if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln1); err != nil {
-		return "", err
-	}
-	if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "two.json", testVuln2); err != nil {
+	if err := createDirAndFile(path.Join(dbName, "/golang.org/example/"), "one.json", testVuln); err != nil {
 		return "", err
 	}
 	if err := createDirAndFile(path.Join(dbName, ""), "index.json", index); err != nil {
@@ -99,8 +85,7 @@
 	}
 
 	// Create a local http database.
-	http.HandleFunc("/golang.org/example/one.json", serveTestVuln1)
-	http.HandleFunc("/golang.org/example/two.json", serveTestVuln2)
+	http.HandleFunc("/golang.org/example/one.json", serveTestVuln)
 	http.HandleFunc("/index.json", serveIndex)
 
 	l, err := net.Listen("tcp", "127.0.0.1:")
@@ -121,20 +106,20 @@
 		name        string
 		source      string
 		createCache func() Cache
-		noVulns     int
-		summaries   map[string]string
+		// cache summary for testVuln
+		summary string
 	}{
 		// Test the http client without any cache.
-		{name: "http-no-cache", source: "http://localhost:" + port, createCache: func() Cache { return nil }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+		{name: "http-no-cache", source: "http://localhost:" + port, createCache: func() Cache { return nil }, summary: ""},
 		// Test the http client with empty cache.
-		{name: "http-empty-cache", source: "http://localhost:" + port, createCache: func() Cache { return &fsCache{} }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+		{name: "http-empty-cache", source: "http://localhost:" + port, createCache: func() Cache { return &fsCache{} }, summary: ""},
 		// Test the client with non-stale cache containing a version of testVuln2 where Summary="cached".
-		{name: "http-cache", source: "http://localhost:" + port, createCache: cachedTestVuln2("localhost"), noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": "cached"}},
+		{name: "http-cache", source: "http://localhost:" + port, createCache: cachedTestVuln("localhost"), summary: "cached"},
 		// Repeat the same for local file client.
-		{name: "file-no-cache", source: "file://" + localDBName, createCache: func() Cache { return nil }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
-		{name: "file-empty-cache", source: "file://" + localDBName, createCache: func() Cache { return &fsCache{} }, noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+		{name: "file-no-cache", source: "file://" + localDBName, createCache: func() Cache { return nil }, summary: ""},
+		{name: "file-empty-cache", source: "file://" + localDBName, createCache: func() Cache { return &fsCache{} }, summary: ""},
 		// Cache does not play a role in local file databases.
-		{name: "file-cache", source: "file://" + localDBName, createCache: cachedTestVuln2(localDBName), noVulns: 2, summaries: map[string]string{"ID1": "", "ID2": ""}},
+		{name: "file-cache", source: "file://" + localDBName, createCache: cachedTestVuln(localDBName), summary: ""},
 	} {
 		// Create fresh cache location each time.
 		cacheRoot = t.TempDir()
@@ -144,18 +129,17 @@
 			t.Fatal(err)
 		}
 
-		vulns, err := client.Get([]string{"golang.org/example/one", "golang.org/example/two"})
+		vulns, err := client.Get("golang.org/example/one")
 		if err != nil {
 			t.Fatal(err)
 		}
-		if len(vulns) != test.noVulns {
-			t.Errorf("want %v vulns for %s; got %v", test.noVulns, test.name, len(vulns))
+
+		if len(vulns) != 1 {
+			t.Errorf("%s: want 1 vuln for golang.org/example/one; got %v", test.name, len(vulns))
 		}
 
-		for _, v := range vulns {
-			if s, ok := test.summaries[v.ID]; !ok || v.Details != s {
-				t.Errorf("want '%s' summary for vuln with id %v in %s; got '%s'", s, v.ID, test.name, v.Details)
-			}
+		if v := vulns[0]; v.Details != test.summary {
+			t.Errorf("%s: want '%s' summary for testVuln; got '%s'", test.name, test.summary, v.Details)
 		}
 	}
 }
@@ -177,11 +161,12 @@
 	defer ts.Close()
 
 	hs := &httpSource{url: ts.URL, c: new(http.Client)}
-	_, err := hs.Get([]string{"a", "b", "c"})
-	if err != nil {
-		t.Fatalf("unexpected error: %s", err)
+	for _, module := range []string{"a", "b", "c"} {
+		if _, err := hs.Get(module); err != nil {
+			t.Fatalf("unexpected error: %s", err)
+		}
 	}
-	expectedFetches := map[string]int{"/index.json": 1, "/a.json": 1, "/b.json": 1}
+	expectedFetches := map[string]int{"/index.json": 3, "/a.json": 1, "/b.json": 1}
 	if !reflect.DeepEqual(fetches, expectedFetches) {
 		t.Errorf("unexpected fetches, got %v, want %v", fetches, expectedFetches)
 	}