client: add ListIDs
Add a method to list all the IDs in the database.
Change-Id: Id4a3eca2abcee2a35de28a672e13ad613c6fcfef
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/357613
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/client/client.go b/client/client.go
index 74916b1..88c8333 100644
--- a/client/client.go
+++ b/client/client.go
@@ -40,6 +40,7 @@
"net/url"
"os"
"path/filepath"
+ "sort"
"strings"
"time"
@@ -56,6 +57,9 @@
// GetByID returns the entry with the given ID, or (nil, nil) if there isn't
// one.
GetByID(string) (*osv.Entry, error)
+
+ // ListIDs returns the IDs of all entries in the database.
+ ListIDs() ([]string, error)
}
type source interface {
@@ -95,6 +99,18 @@
return &e, nil
}
+func (ls *localSource) ListIDs() ([]string, error) {
+ content, err := ioutil.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, "index.json"))
+ if err != nil {
+ return nil, err
+ }
+ var ids []string
+ if err := json.Unmarshal(content, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
func (ls *localSource) Index() (osv.DBIndex, error) {
var index osv.DBIndex
b, err := ioutil.ReadFile(filepath.Join(ls.dir, "index.json"))
@@ -235,6 +251,18 @@
return &e, nil
}
+func (hs *httpSource) ListIDs() ([]string, error) {
+ content, err := hs.readBody(fmt.Sprintf("%s/%s/index.json", hs.url, internal.IDDirectory))
+ if err != nil {
+ return nil, err
+ }
+ var ids []string
+ if err := json.Unmarshal(content, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
func (hs *httpSource) readBody(url string) ([]byte, error) {
resp, err := hs.c.Get(url)
if err != nil {
@@ -313,3 +341,24 @@
}
return nil, nil
}
+
+// ListIDs returns the union of the IDs from all sources,
+// sorted lexically.
+func (c *client) ListIDs() ([]string, error) {
+ idSet := map[string]bool{}
+ for _, s := range c.sources {
+ ids, err := s.ListIDs()
+ if err != nil {
+ return nil, err
+ }
+ for _, id := range ids {
+ idSet[id] = true
+ }
+ }
+ var ids []string
+ for id := range idSet {
+ ids = append(ids, id)
+ }
+ sort.Strings(ids)
+ return ids, nil
+}
diff --git a/client/client_test.go b/client/client_test.go
index d09e3ec..83e5218 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -7,6 +7,7 @@
import (
"encoding/json"
"fmt"
+ "io"
"io/ioutil"
"net"
"net/http"
@@ -308,3 +309,49 @@
})
}
}
+
+func TestListIDs(t *testing.T) {
+ if runtime.GOOS == "js" {
+ t.Skip("skipping test: no network on js")
+ }
+
+ http.HandleFunc("/ID/index.json", func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, `["ID"]`)
+ })
+
+ l, err := net.Listen("tcp", "127.0.0.1:")
+ if err != nil {
+ t.Fatalf("failed to listen on 127.0.0.1: %s", err)
+ }
+ _, port, _ := net.SplitHostPort(l.Addr().String())
+ go func() { http.Serve(l, nil) }()
+
+ // Create a local file database.
+ localDBName, err := localDB(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(localDBName)
+
+ want := []string{"ID"}
+ for _, test := range []struct {
+ name string
+ source string
+ }{
+ {name: "http", source: "http://localhost:" + port},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ client, err := NewClient([]string{test.source}, Options{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := client.ListIDs()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !cmp.Equal(got, want) {
+ t.Errorf("got\n%+v\nwant\n%+v", got, want)
+ }
+ })
+ }
+}