client: add context arg
Add a context.Context argument to methods that might involve network
traffic.
Change-Id: Ib743a7b1a8c80d09d16c4529f5d8822726e8b122
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/365054
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/client/client.go b/client/client.go
index 21b71ec..7d994cf 100644
--- a/client/client.go
+++ b/client/client.go
@@ -33,6 +33,7 @@
package client
import (
+ "context"
"encoding/json"
"fmt"
"io/ioutil"
@@ -53,21 +54,21 @@
type Client interface {
// GetByModule returns the entries that affect the given module path.
// It returns (nil, nil) if there are none.
- GetByModule(string) ([]*osv.Entry, error)
+ GetByModule(context.Context, string) ([]*osv.Entry, error)
// GetByID returns the entry with the given ID, or (nil, nil) if there isn't
// one.
- GetByID(string) (*osv.Entry, error)
+ GetByID(context.Context, string) (*osv.Entry, error)
// ListIDs returns the IDs of all entries in the database.
- ListIDs() ([]string, error)
+ ListIDs(context.Context) ([]string, error)
unexported() // ensures that adding a method won't break users
}
type source interface {
Client
- Index() (osv.DBIndex, error)
+ Index(context.Context) (osv.DBIndex, error)
}
type localSource struct {
@@ -76,7 +77,7 @@
func (*localSource) unexported() {}
-func (ls *localSource) GetByModule(module string) (_ []*osv.Entry, err error) {
+func (ls *localSource) GetByModule(_ context.Context, module string) (_ []*osv.Entry, err error) {
defer derrors.Wrap(&err, "GetByModule(%q)", module)
content, err := ioutil.ReadFile(filepath.Join(ls.dir, module+".json"))
if os.IsNotExist(err) {
@@ -91,7 +92,7 @@
return e, nil
}
-func (ls *localSource) GetByID(id string) (_ *osv.Entry, err error) {
+func (ls *localSource) GetByID(_ context.Context, id string) (_ *osv.Entry, err error) {
defer derrors.Wrap(&err, "GetByID(%q)", id)
content, err := ioutil.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, id+".json"))
if os.IsNotExist(err) {
@@ -106,7 +107,7 @@
return &e, nil
}
-func (ls *localSource) ListIDs() (_ []string, err error) {
+func (ls *localSource) ListIDs(context.Context) (_ []string, err error) {
defer derrors.Wrap(&err, "ListIDs()")
content, err := ioutil.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, "index.json"))
if err != nil {
@@ -119,7 +120,7 @@
return ids, nil
}
-func (ls *localSource) Index() (_ osv.DBIndex, err error) {
+func (ls *localSource) Index(context.Context) (_ osv.DBIndex, err error) {
defer derrors.Wrap(&err, "Index()")
var index osv.DBIndex
b, err := ioutil.ReadFile(filepath.Join(ls.dir, "index.json"))
@@ -139,7 +140,7 @@
dbName string
}
-func (hs *httpSource) Index() (_ osv.DBIndex, err error) {
+func (hs *httpSource) Index(ctx context.Context) (_ osv.DBIndex, err error) {
defer derrors.Wrap(&err, "Index()")
var cachedIndex osv.DBIndex
@@ -161,7 +162,7 @@
}
}
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/index.json", hs.url), nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/index.json", hs.url), nil)
if err != nil {
return nil, err
}
@@ -205,10 +206,10 @@
func (*httpSource) unexported() {}
-func (hs *httpSource) GetByModule(module string) (_ []*osv.Entry, err error) {
+func (hs *httpSource) GetByModule(ctx context.Context, module string) (_ []*osv.Entry, err error) {
defer derrors.Wrap(&err, "GetByModule(%q)", module)
- index, err := hs.Index()
+ index, err := hs.Index(ctx)
if err != nil {
return nil, err
}
@@ -235,7 +236,7 @@
}
}
- content, err := hs.readBody(fmt.Sprintf("%s/%s.json", hs.url, module))
+ content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s.json", hs.url, module))
if err != nil || content == nil {
return nil, err
}
@@ -254,10 +255,10 @@
return e, nil
}
-func (hs *httpSource) GetByID(id string) (_ *osv.Entry, err error) {
+func (hs *httpSource) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
defer derrors.Wrap(&err, "GetByID(%q)", id)
- content, err := hs.readBody(fmt.Sprintf("%s/%s/%s.json", hs.url, internal.IDDirectory, id))
+ content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s/%s.json", hs.url, internal.IDDirectory, id))
if err != nil || content == nil {
return nil, err
}
@@ -268,10 +269,10 @@
return &e, nil
}
-func (hs *httpSource) ListIDs() (_ []string, err error) {
+func (hs *httpSource) ListIDs(ctx context.Context) (_ []string, err error) {
defer derrors.Wrap(&err, "ListIDs()")
- content, err := hs.readBody(fmt.Sprintf("%s/%s/index.json", hs.url, internal.IDDirectory))
+ content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s/index.json", hs.url, internal.IDDirectory))
if err != nil {
return nil, err
}
@@ -282,8 +283,12 @@
return ids, nil
}
-func (hs *httpSource) readBody(url string) ([]byte, error) {
- resp, err := hs.c.Get(url)
+func (hs *httpSource) readBody(ctx context.Context, url string) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := hs.c.Do(req)
if err != nil {
return nil, err
}
@@ -338,12 +343,12 @@
func (*client) unexported() {}
-func (c *client) GetByModule(module string) (_ []*osv.Entry, err error) {
+func (c *client) GetByModule(ctx context.Context, module string) (_ []*osv.Entry, err error) {
defer derrors.Wrap(&err, "GetByModule(%q)", module)
var entries []*osv.Entry
// probably should be parallelized
for _, s := range c.sources {
- e, err := s.GetByModule(module)
+ e, err := s.GetByModule(ctx, module)
if err != nil {
return nil, err // be failure tolerant?
}
@@ -352,10 +357,10 @@
return entries, nil
}
-func (c *client) GetByID(id string) (_ *osv.Entry, err error) {
+func (c *client) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
defer derrors.Wrap(&err, "GetByID(%q)", id)
for _, s := range c.sources {
- entry, err := s.GetByID(id)
+ entry, err := s.GetByID(ctx, id)
if err != nil {
return nil, err // be failure tolerant?
}
@@ -368,11 +373,11 @@
// ListIDs returns the union of the IDs from all sources,
// sorted lexically.
-func (c *client) ListIDs() (_ []string, err error) {
+func (c *client) ListIDs(ctx context.Context) (_ []string, err error) {
defer derrors.Wrap(&err, "ListIDs()")
idSet := map[string]bool{}
for _, s := range c.sources {
- ids, err := s.ListIDs()
+ ids, err := s.ListIDs(ctx)
if err != nil {
return nil, err
}
diff --git a/client/client_test.go b/client/client_test.go
index b959764..9b2ca26 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -5,6 +5,7 @@
package client
import (
+ "context"
"encoding/json"
"fmt"
"io"
@@ -157,7 +158,7 @@
if runtime.GOOS == "js" {
t.Skip("skipping test: no network on js")
}
-
+ ctx := context.Background()
// Create a local http database.
srv := newTestServer()
defer srv.Close()
@@ -198,7 +199,7 @@
t.Fatal(err)
}
- vulns, err := client.GetByModule("golang.org/example/one")
+ vulns, err := client.GetByModule(ctx, "golang.org/example/one")
if err != nil {
t.Fatal(err)
}
@@ -231,7 +232,7 @@
hs := &httpSource{url: ts.URL, c: new(http.Client)}
for _, module := range []string{"a", "b", "c"} {
- if _, err := hs.GetByModule(module); err != nil {
+ if _, err := hs.GetByModule(context.Background(), module); err != nil {
t.Fatalf("unexpected error: %s", err)
}
}
@@ -269,7 +270,7 @@
if err != nil {
t.Fatal(err)
}
- vulns, err := client.GetByModule("a")
+ vulns, err := client.GetByModule(context.Background(), "a")
if err != nil {
t.Fatal(err)
}
@@ -309,7 +310,7 @@
if err != nil {
t.Fatal(err)
}
- got, err := client.GetByID("ID")
+ got, err := client.GetByID(context.Background(), "ID")
if err != nil {
t.Fatal(err)
}
@@ -348,7 +349,7 @@
if err != nil {
t.Fatal(err)
}
- got, err := client.ListIDs()
+ got, err := client.ListIDs(context.Background())
if err != nil {
t.Fatal(err)
}