client: add context arg

Add a context.Context argument to methods that might involve network
traffic.

Change-Id: Ib743a7b1a8c80d09d16c4529f5d8822726e8b122
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/365054
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/client/client.go b/client/client.go
index 21b71ec..7d994cf 100644
--- a/client/client.go
+++ b/client/client.go
@@ -33,6 +33,7 @@
 package client
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
@@ -53,21 +54,21 @@
 type Client interface {
 	// GetByModule returns the entries that affect the given module path.
 	// It returns (nil, nil) if there are none.
-	GetByModule(string) ([]*osv.Entry, error)
+	GetByModule(context.Context, string) ([]*osv.Entry, error)
 
 	// GetByID returns the entry with the given ID, or (nil, nil) if there isn't
 	// one.
-	GetByID(string) (*osv.Entry, error)
+	GetByID(context.Context, string) (*osv.Entry, error)
 
 	// ListIDs returns the IDs of all entries in the database.
-	ListIDs() ([]string, error)
+	ListIDs(context.Context) ([]string, error)
 
 	unexported() // ensures that adding a method won't break users
 }
 
 type source interface {
 	Client
-	Index() (osv.DBIndex, error)
+	Index(context.Context) (osv.DBIndex, error)
 }
 
 type localSource struct {
@@ -76,7 +77,7 @@
 
 func (*localSource) unexported() {}
 
-func (ls *localSource) GetByModule(module string) (_ []*osv.Entry, err error) {
+func (ls *localSource) GetByModule(_ context.Context, module string) (_ []*osv.Entry, err error) {
 	defer derrors.Wrap(&err, "GetByModule(%q)", module)
 	content, err := ioutil.ReadFile(filepath.Join(ls.dir, module+".json"))
 	if os.IsNotExist(err) {
@@ -91,7 +92,7 @@
 	return e, nil
 }
 
-func (ls *localSource) GetByID(id string) (_ *osv.Entry, err error) {
+func (ls *localSource) GetByID(_ context.Context, id string) (_ *osv.Entry, err error) {
 	defer derrors.Wrap(&err, "GetByID(%q)", id)
 	content, err := ioutil.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, id+".json"))
 	if os.IsNotExist(err) {
@@ -106,7 +107,7 @@
 	return &e, nil
 }
 
-func (ls *localSource) ListIDs() (_ []string, err error) {
+func (ls *localSource) ListIDs(context.Context) (_ []string, err error) {
 	defer derrors.Wrap(&err, "ListIDs()")
 	content, err := ioutil.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, "index.json"))
 	if err != nil {
@@ -119,7 +120,7 @@
 	return ids, nil
 }
 
-func (ls *localSource) Index() (_ osv.DBIndex, err error) {
+func (ls *localSource) Index(context.Context) (_ osv.DBIndex, err error) {
 	defer derrors.Wrap(&err, "Index()")
 	var index osv.DBIndex
 	b, err := ioutil.ReadFile(filepath.Join(ls.dir, "index.json"))
@@ -139,7 +140,7 @@
 	dbName string
 }
 
-func (hs *httpSource) Index() (_ osv.DBIndex, err error) {
+func (hs *httpSource) Index(ctx context.Context) (_ osv.DBIndex, err error) {
 	defer derrors.Wrap(&err, "Index()")
 
 	var cachedIndex osv.DBIndex
@@ -161,7 +162,7 @@
 		}
 	}
 
-	req, err := http.NewRequest("GET", fmt.Sprintf("%s/index.json", hs.url), nil)
+	req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/index.json", hs.url), nil)
 	if err != nil {
 		return nil, err
 	}
@@ -205,10 +206,10 @@
 
 func (*httpSource) unexported() {}
 
-func (hs *httpSource) GetByModule(module string) (_ []*osv.Entry, err error) {
+func (hs *httpSource) GetByModule(ctx context.Context, module string) (_ []*osv.Entry, err error) {
 	defer derrors.Wrap(&err, "GetByModule(%q)", module)
 
-	index, err := hs.Index()
+	index, err := hs.Index(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -235,7 +236,7 @@
 		}
 	}
 
-	content, err := hs.readBody(fmt.Sprintf("%s/%s.json", hs.url, module))
+	content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s.json", hs.url, module))
 	if err != nil || content == nil {
 		return nil, err
 	}
@@ -254,10 +255,10 @@
 	return e, nil
 }
 
-func (hs *httpSource) GetByID(id string) (_ *osv.Entry, err error) {
+func (hs *httpSource) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
 	defer derrors.Wrap(&err, "GetByID(%q)", id)
 
-	content, err := hs.readBody(fmt.Sprintf("%s/%s/%s.json", hs.url, internal.IDDirectory, id))
+	content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s/%s.json", hs.url, internal.IDDirectory, id))
 	if err != nil || content == nil {
 		return nil, err
 	}
@@ -268,10 +269,10 @@
 	return &e, nil
 }
 
-func (hs *httpSource) ListIDs() (_ []string, err error) {
+func (hs *httpSource) ListIDs(ctx context.Context) (_ []string, err error) {
 	defer derrors.Wrap(&err, "ListIDs()")
 
-	content, err := hs.readBody(fmt.Sprintf("%s/%s/index.json", hs.url, internal.IDDirectory))
+	content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s/index.json", hs.url, internal.IDDirectory))
 	if err != nil {
 		return nil, err
 	}
@@ -282,8 +283,12 @@
 	return ids, nil
 }
 
-func (hs *httpSource) readBody(url string) ([]byte, error) {
-	resp, err := hs.c.Get(url)
+func (hs *httpSource) readBody(ctx context.Context, url string) ([]byte, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+	resp, err := hs.c.Do(req)
 	if err != nil {
 		return nil, err
 	}
@@ -338,12 +343,12 @@
 
 func (*client) unexported() {}
 
-func (c *client) GetByModule(module string) (_ []*osv.Entry, err error) {
+func (c *client) GetByModule(ctx context.Context, module string) (_ []*osv.Entry, err error) {
 	defer derrors.Wrap(&err, "GetByModule(%q)", module)
 	var entries []*osv.Entry
 	// probably should be parallelized
 	for _, s := range c.sources {
-		e, err := s.GetByModule(module)
+		e, err := s.GetByModule(ctx, module)
 		if err != nil {
 			return nil, err // be failure tolerant?
 		}
@@ -352,10 +357,10 @@
 	return entries, nil
 }
 
-func (c *client) GetByID(id string) (_ *osv.Entry, err error) {
+func (c *client) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
 	defer derrors.Wrap(&err, "GetByID(%q)", id)
 	for _, s := range c.sources {
-		entry, err := s.GetByID(id)
+		entry, err := s.GetByID(ctx, id)
 		if err != nil {
 			return nil, err // be failure tolerant?
 		}
@@ -368,11 +373,11 @@
 
 // ListIDs returns the union of the IDs from all sources,
 // sorted lexically.
-func (c *client) ListIDs() (_ []string, err error) {
+func (c *client) ListIDs(ctx context.Context) (_ []string, err error) {
 	defer derrors.Wrap(&err, "ListIDs()")
 	idSet := map[string]bool{}
 	for _, s := range c.sources {
-		ids, err := s.ListIDs()
+		ids, err := s.ListIDs(ctx)
 		if err != nil {
 			return nil, err
 		}
diff --git a/client/client_test.go b/client/client_test.go
index b959764..9b2ca26 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -5,6 +5,7 @@
 package client
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -157,7 +158,7 @@
 	if runtime.GOOS == "js" {
 		t.Skip("skipping test: no network on js")
 	}
-
+	ctx := context.Background()
 	// Create a local http database.
 	srv := newTestServer()
 	defer srv.Close()
@@ -198,7 +199,7 @@
 			t.Fatal(err)
 		}
 
-		vulns, err := client.GetByModule("golang.org/example/one")
+		vulns, err := client.GetByModule(ctx, "golang.org/example/one")
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -231,7 +232,7 @@
 
 	hs := &httpSource{url: ts.URL, c: new(http.Client)}
 	for _, module := range []string{"a", "b", "c"} {
-		if _, err := hs.GetByModule(module); err != nil {
+		if _, err := hs.GetByModule(context.Background(), module); err != nil {
 			t.Fatalf("unexpected error: %s", err)
 		}
 	}
@@ -269,7 +270,7 @@
 	if err != nil {
 		t.Fatal(err)
 	}
-	vulns, err := client.GetByModule("a")
+	vulns, err := client.GetByModule(context.Background(), "a")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -309,7 +310,7 @@
 			if err != nil {
 				t.Fatal(err)
 			}
-			got, err := client.GetByID("ID")
+			got, err := client.GetByID(context.Background(), "ID")
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -348,7 +349,7 @@
 			if err != nil {
 				t.Fatal(err)
 			}
-			got, err := client.ListIDs()
+			got, err := client.ListIDs(context.Background())
 			if err != nil {
 				t.Fatal(err)
 			}