client: add ListIDs

Add a method to list all the IDs in the database.

Change-Id: Id4a3eca2abcee2a35de28a672e13ad613c6fcfef
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/357613
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/client/client.go b/client/client.go
index 74916b1..88c8333 100644
--- a/client/client.go
+++ b/client/client.go
@@ -40,6 +40,7 @@
 	"net/url"
 	"os"
 	"path/filepath"
+	"sort"
 	"strings"
 	"time"
 
@@ -56,6 +57,9 @@
 	// GetByID returns the entry with the given ID, or (nil, nil) if there isn't
 	// one.
 	GetByID(string) (*osv.Entry, error)
+
+	// ListIDs returns the IDs of all entries in the database.
+	ListIDs() ([]string, error)
 }
 
 type source interface {
@@ -95,6 +99,18 @@
 	return &e, nil
 }
 
+func (ls *localSource) ListIDs() ([]string, error) {
+	content, err := ioutil.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, "index.json"))
+	if err != nil {
+		return nil, err
+	}
+	var ids []string
+	if err := json.Unmarshal(content, &ids); err != nil {
+		return nil, err
+	}
+	return ids, nil
+}
+
 func (ls *localSource) Index() (osv.DBIndex, error) {
 	var index osv.DBIndex
 	b, err := ioutil.ReadFile(filepath.Join(ls.dir, "index.json"))
@@ -235,6 +251,18 @@
 	return &e, nil
 }
 
+func (hs *httpSource) ListIDs() ([]string, error) {
+	content, err := hs.readBody(fmt.Sprintf("%s/%s/index.json", hs.url, internal.IDDirectory))
+	if err != nil {
+		return nil, err
+	}
+	var ids []string
+	if err := json.Unmarshal(content, &ids); err != nil {
+		return nil, err
+	}
+	return ids, nil
+}
+
 func (hs *httpSource) readBody(url string) ([]byte, error) {
 	resp, err := hs.c.Get(url)
 	if err != nil {
@@ -313,3 +341,24 @@
 	}
 	return nil, nil
 }
+
+// ListIDs returns the union of the IDs from all sources,
+// sorted lexically.
+func (c *client) ListIDs() ([]string, error) {
+	idSet := map[string]bool{}
+	for _, s := range c.sources {
+		ids, err := s.ListIDs()
+		if err != nil {
+			return nil, err
+		}
+		for _, id := range ids {
+			idSet[id] = true
+		}
+	}
+	var ids []string
+	for id := range idSet {
+		ids = append(ids, id)
+	}
+	sort.Strings(ids)
+	return ids, nil
+}
diff --git a/client/client_test.go b/client/client_test.go
index d09e3ec..83e5218 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -7,6 +7,7 @@
 import (
 	"encoding/json"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net"
 	"net/http"
@@ -308,3 +309,49 @@
 		})
 	}
 }
+
+func TestListIDs(t *testing.T) {
+	if runtime.GOOS == "js" {
+		t.Skip("skipping test: no network on js")
+	}
+
+	http.HandleFunc("/ID/index.json", func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, `["ID"]`)
+	})
+
+	l, err := net.Listen("tcp", "127.0.0.1:")
+	if err != nil {
+		t.Fatalf("failed to listen on 127.0.0.1: %s", err)
+	}
+	_, port, _ := net.SplitHostPort(l.Addr().String())
+	go func() { http.Serve(l, nil) }()
+
+	// Create a local file database.
+	localDBName, err := localDB(t)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(localDBName)
+
+	want := []string{"ID"}
+	for _, test := range []struct {
+		name   string
+		source string
+	}{
+		{name: "http", source: "http://localhost:" + port},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			client, err := NewClient([]string{test.source}, Options{})
+			if err != nil {
+				t.Fatal(err)
+			}
+			got, err := client.ListIDs()
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !cmp.Equal(got, want) {
+				t.Errorf("got\n%+v\nwant\n%+v", got, want)
+			}
+		})
+	}
+}