client: split client.go into file.go and http.go
Pure code in motion.
Change-Id: I5c22b48dc896b0f196969331f96db02e44664fe7
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/474221
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Julie Qiu <julieqiu@google.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
diff --git a/client/client.go b/client/client.go
index 6018da9..528ee2a 100644
--- a/client/client.go
+++ b/client/client.go
@@ -30,14 +30,10 @@
import (
"context"
- "encoding/json"
"fmt"
- "io"
"net/http"
"net/url"
"os"
- "path"
- "path/filepath"
"sort"
"strings"
"time"
@@ -78,72 +74,6 @@
unexported() // ensures that adding a method won't break users
}
-type localSource struct {
- dir string
-}
-
-func (*localSource) unexported() {}
-
-func (ls *localSource) GetByModule(ctx context.Context, modulePath string) (_ []*osv.Entry, err error) {
- defer derrors.Wrap(&err, "localSource.GetByModule(%q)", modulePath)
-
- index, err := ls.Index(ctx)
- if err != nil {
- return nil, err
- }
- // Query index first to be consistent with the way httpSource.GetByModule works.
- // Prevents opening and stating files on disk that don't need to be touched. Also
- // solves #56179.
- if _, present := index[modulePath]; !present {
- return nil, nil
- }
-
- epath, err := EscapeModulePath(modulePath)
- if err != nil {
- return nil, err
- }
- content, err := os.ReadFile(filepath.Join(ls.dir, epath+".json"))
- if os.IsNotExist(err) {
- return nil, nil
- } else if err != nil {
- return nil, err
- }
- var e []*osv.Entry
- if err = json.Unmarshal(content, &e); err != nil {
- return nil, err
- }
- return e, nil
-}
-
-func (ls *localSource) GetByID(_ context.Context, id string) (_ *osv.Entry, err error) {
- defer derrors.Wrap(&err, "GetByID(%q)", id)
- content, err := os.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, id+".json"))
- if os.IsNotExist(err) {
- return nil, nil
- } else if err != nil {
- return nil, err
- }
- var e osv.Entry
- if err = json.Unmarshal(content, &e); err != nil {
- return nil, err
- }
- return &e, nil
-}
-
-func (ls *localSource) GetByAlias(ctx context.Context, alias string) (entries []*osv.Entry, err error) {
- defer derrors.Wrap(&err, "localSource.GetByAlias(%q)", alias)
-
- aliasToIDs, err := localReadJSON[map[string][]string](ctx, ls, "aliases.json")
- if err != nil {
- return nil, err
- }
- ids := aliasToIDs[alias]
- if len(ids) == 0 {
- return nil, nil
- }
- return getByIDs(ctx, ls, ids)
-}
-
func getByIDs(ctx context.Context, client Client, ids []string) ([]*osv.Entry, error) {
var entries []*osv.Entry
for _, id := range ids {
@@ -156,166 +86,6 @@
return entries, nil
}
-func (ls *localSource) ListIDs(ctx context.Context) (_ []string, err error) {
- defer derrors.Wrap(&err, "ListIDs()")
-
- return localReadJSON[[]string](ctx, ls, filepath.Join(internal.IDDirectory, "index.json"))
-}
-
-func (ls *localSource) LastModifiedTime(context.Context) (_ time.Time, err error) {
- defer derrors.Wrap(&err, "LastModifiedTime()")
-
- // Assume that if anything changes, the index does.
- info, err := os.Stat(filepath.Join(ls.dir, "index.json"))
- if err != nil {
- return time.Time{}, err
- }
- return info.ModTime(), nil
-}
-
-func (ls *localSource) Index(ctx context.Context) (_ DBIndex, err error) {
- defer derrors.Wrap(&err, "Index()")
-
- return localReadJSON[DBIndex](ctx, ls, "index.json")
-}
-
-func localReadJSON[T any](_ context.Context, ls *localSource, relativePath string) (T, error) {
- var zero T
- content, err := os.ReadFile(filepath.Join(ls.dir, relativePath))
- if err != nil {
- return zero, err
- }
- var t T
- if err := json.Unmarshal(content, &t); err != nil {
- return zero, err
- }
- return t, nil
-}
-
-type httpSource struct {
- url string // the base URI of the source (without trailing "/"). e.g. https://vuln.golang.org
- c *http.Client
- cache Cache
- dbName string
-
- // indexCalls counts the number of times Index()
- // method has been called. httpCalls counts the
- // number of times GetByModule makes an http request
- // to the vuln db for a module path. Used for testing
- // privacy properties of httpSource.
- indexCalls int
- httpCalls int
-}
-
-func (hs *httpSource) Index(ctx context.Context) (_ DBIndex, err error) {
- hs.indexCalls++ // for testing privacy properties
- defer derrors.Wrap(&err, "Index()")
-
- var cachedIndex DBIndex
- var cachedIndexRetrieved *time.Time
-
- if hs.cache != nil {
- index, retrieved, err := hs.cache.ReadIndex(hs.dbName)
- if err != nil {
- return nil, err
- }
-
- cachedIndex = index
- if cachedIndex != nil {
- if time.Since(retrieved) < time.Hour*2 {
- return cachedIndex, nil
- }
-
- cachedIndexRetrieved = &retrieved
- }
- }
-
- req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/index.json", hs.url), nil)
- if err != nil {
- return nil, err
- }
- if cachedIndexRetrieved != nil {
- req.Header.Add("If-Modified-Since", cachedIndexRetrieved.Format(http.TimeFormat))
- }
- resp, err := hs.c.Do(req)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- if cachedIndexRetrieved != nil && resp.StatusCode == http.StatusNotModified {
- // If status has not been modified, this is equivalent to returning the
- // same index. We update the timestamp so the next cache index read does
- // not require a roundtrip to the server.
- if err = hs.cache.WriteIndex(hs.dbName, cachedIndex, time.Now()); err != nil {
- return nil, err
- }
- return cachedIndex, nil
- }
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
- }
- b, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, err
- }
- var index DBIndex
- if err = json.Unmarshal(b, &index); err != nil {
- return nil, err
- }
-
- if hs.cache != nil {
- if err = hs.cache.WriteIndex(hs.dbName, index, time.Now()); err != nil {
- return nil, err
- }
- }
-
- return index, nil
-}
-
-func (*httpSource) unexported() {}
-
-func (hs *httpSource) GetByModule(ctx context.Context, modulePath string) (_ []*osv.Entry, err error) {
- defer derrors.Wrap(&err, "httpSource.GetByModule(%q)", modulePath)
-
- index, err := hs.Index(ctx)
- if err != nil {
- return nil, err
- }
-
- lastModified, present := index[modulePath]
- if !present {
- return nil, nil
- }
-
- if hs.cache != nil {
- cached, err := hs.cache.ReadEntries(hs.dbName, modulePath)
- if err != nil {
- return nil, err
- }
- if len(cached) > 0 && !latestModifiedTime(cached).Before(lastModified) {
- return cached, nil
- }
- }
-
- epath, err := EscapeModulePath(modulePath)
- if err != nil {
- return nil, err
- }
- hs.httpCalls++ // for testing privacy properties
- entries, err := httpReadJSON[[]*osv.Entry](ctx, hs, epath+".json")
- if err != nil || entries == nil {
- return nil, err
- }
- // TODO: we may want to check that the returned entries actually match
- // the module we asked about, so that the cache cannot be poisoned
- if hs.cache != nil {
- if err := hs.cache.WriteEntries(hs.dbName, modulePath, entries); err != nil {
- return nil, err
- }
- }
- return entries, nil
-}
-
// Pseudo-module paths used for parts of the Go system.
// These are technically not valid module paths, so we
// mustn't pass them to module.EscapePath.
@@ -356,101 +126,6 @@
return t
}
-func (hs *httpSource) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
- defer derrors.Wrap(&err, "GetByID(%q)", id)
-
- return httpReadJSON[*osv.Entry](ctx, hs, fmt.Sprintf("%s/%s.json", internal.IDDirectory, id))
-}
-
-func (hs *httpSource) GetByAlias(ctx context.Context, alias string) (entries []*osv.Entry, err error) {
- defer derrors.Wrap(&err, "httpSource.GetByAlias(%q)", alias)
-
- aliasToIDs, err := httpReadJSON[map[string][]string](ctx, hs, "aliases.json")
- if err != nil {
- return nil, err
- }
- ids := aliasToIDs[alias]
- if len(ids) == 0 {
- return nil, nil
- }
- return getByIDs(ctx, hs, ids)
-}
-
-func (hs *httpSource) ListIDs(ctx context.Context) (_ []string, err error) {
- defer derrors.Wrap(&err, "ListIDs()")
-
- return httpReadJSON[[]string](ctx, hs, path.Join(internal.IDDirectory, "index.json"))
-}
-
-func httpReadJSON[T any](ctx context.Context, hs *httpSource, relativePath string) (T, error) {
- var zero T
- content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s", hs.url, relativePath))
- if err != nil {
- return zero, err
- }
- if len(content) == 0 {
- return zero, nil
- }
- var t T
- if err := json.Unmarshal(content, &t); err != nil {
- return zero, err
- }
- return t, nil
-}
-
-// This is the format for the last-modified header, as described at
-// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified.
-var lastModifiedFormat = "Mon, 2 Jan 2006 15:04:05 GMT"
-
-func (hs *httpSource) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
- defer derrors.Wrap(&err, "LastModifiedTime()")
-
- // Assume that if anything changes, the index does.
- url := fmt.Sprintf("%s/index.json", hs.url)
- req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
- if err != nil {
- return time.Time{}, err
- }
- resp, err := hs.c.Do(req)
- if err != nil {
- return time.Time{}, err
- }
- if resp.StatusCode != 200 {
- return time.Time{}, fmt.Errorf("got status code %d", resp.StatusCode)
- }
- h := resp.Header.Get("Last-Modified")
- return time.Parse(lastModifiedFormat, h)
-}
-
-func (hs *httpSource) readBody(ctx context.Context, url string) ([]byte, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
- if err != nil {
- return nil, err
- }
- resp, err := hs.c.Do(req)
- if err != nil {
- return nil, err
- }
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusNotFound {
- return nil, nil
- }
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("got HTTP status %s", resp.Status)
- }
- // might want this to be a LimitedReader
- return io.ReadAll(resp.Body)
-}
-
-type client struct {
- sources []Client
-}
-
-type Options struct {
- HTTPClient *http.Client
- HTTPCache Cache
-}
-
func NewClient(sources []string, opts Options) (_ Client, err error) {
defer derrors.Wrap(&err, "NewClient(%v, opts)", sources)
c := &client{}
diff --git a/client/file.go b/client/file.go
new file mode 100644
index 0000000..e82136a
--- /dev/null
+++ b/client/file.go
@@ -0,0 +1,119 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package client
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "time"
+
+ "golang.org/x/vuln/internal"
+ "golang.org/x/vuln/internal/derrors"
+ "golang.org/x/vuln/osv"
+)
+
+type localSource struct {
+ dir string
+}
+
+func (*localSource) unexported() {}
+
+func (ls *localSource) GetByModule(ctx context.Context, modulePath string) (_ []*osv.Entry, err error) {
+ defer derrors.Wrap(&err, "localSource.GetByModule(%q)", modulePath)
+
+ index, err := ls.Index(ctx)
+ if err != nil {
+ return nil, err
+ }
+ // Query index first to be consistent with the way httpSource.GetByModule works.
+ // Prevents opening and stating files on disk that don't need to be touched. Also
+ // solves #56179.
+ if _, present := index[modulePath]; !present {
+ return nil, nil
+ }
+
+ epath, err := EscapeModulePath(modulePath)
+ if err != nil {
+ return nil, err
+ }
+ content, err := os.ReadFile(filepath.Join(ls.dir, epath+".json"))
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ var e []*osv.Entry
+ if err = json.Unmarshal(content, &e); err != nil {
+ return nil, err
+ }
+ return e, nil
+}
+
+func (ls *localSource) GetByID(_ context.Context, id string) (_ *osv.Entry, err error) {
+ defer derrors.Wrap(&err, "GetByID(%q)", id)
+ content, err := os.ReadFile(filepath.Join(ls.dir, internal.IDDirectory, id+".json"))
+ if os.IsNotExist(err) {
+ return nil, nil
+ } else if err != nil {
+ return nil, err
+ }
+ var e osv.Entry
+ if err = json.Unmarshal(content, &e); err != nil {
+ return nil, err
+ }
+ return &e, nil
+}
+
+func (ls *localSource) GetByAlias(ctx context.Context, alias string) (entries []*osv.Entry, err error) {
+ defer derrors.Wrap(&err, "localSource.GetByAlias(%q)", alias)
+
+ aliasToIDs, err := localReadJSON[map[string][]string](ctx, ls, "aliases.json")
+ if err != nil {
+ return nil, err
+ }
+ ids := aliasToIDs[alias]
+ if len(ids) == 0 {
+ return nil, nil
+ }
+ return getByIDs(ctx, ls, ids)
+}
+
+func (ls *localSource) ListIDs(ctx context.Context) (_ []string, err error) {
+ defer derrors.Wrap(&err, "ListIDs()")
+
+ return localReadJSON[[]string](ctx, ls, filepath.Join(internal.IDDirectory, "index.json"))
+}
+
+func (ls *localSource) LastModifiedTime(context.Context) (_ time.Time, err error) {
+ defer derrors.Wrap(&err, "LastModifiedTime()")
+
+ // Assume that if anything changes, the index does.
+ info, err := os.Stat(filepath.Join(ls.dir, "index.json"))
+ if err != nil {
+ return time.Time{}, err
+ }
+ return info.ModTime(), nil
+}
+
+func (ls *localSource) Index(ctx context.Context) (_ DBIndex, err error) {
+ defer derrors.Wrap(&err, "Index()")
+
+ return localReadJSON[DBIndex](ctx, ls, "index.json")
+}
+
+func localReadJSON[T any](_ context.Context, ls *localSource, relativePath string) (T, error) {
+ var zero T
+ content, err := os.ReadFile(filepath.Join(ls.dir, relativePath))
+ if err != nil {
+ return zero, err
+ }
+ var t T
+ if err := json.Unmarshal(content, &t); err != nil {
+ return zero, err
+ }
+ return t, nil
+}
diff --git a/client/http.go b/client/http.go
new file mode 100644
index 0000000..66e17e0
--- /dev/null
+++ b/client/http.go
@@ -0,0 +1,238 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package client
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "path"
+ "time"
+
+ "golang.org/x/vuln/internal"
+ "golang.org/x/vuln/internal/derrors"
+ "golang.org/x/vuln/osv"
+)
+
+type httpSource struct {
+ url string // the base URI of the source (without trailing "/"). e.g. https://vuln.golang.org
+ c *http.Client
+ cache Cache
+ dbName string
+
+ // indexCalls counts the number of times Index()
+ // method has been called. httpCalls counts the
+ // number of times GetByModule makes an http request
+ // to the vuln db for a module path. Used for testing
+ // privacy properties of httpSource.
+ indexCalls int
+ httpCalls int
+}
+
+func (hs *httpSource) Index(ctx context.Context) (_ DBIndex, err error) {
+ hs.indexCalls++ // for testing privacy properties
+ defer derrors.Wrap(&err, "Index()")
+
+ var cachedIndex DBIndex
+ var cachedIndexRetrieved *time.Time
+
+ if hs.cache != nil {
+ index, retrieved, err := hs.cache.ReadIndex(hs.dbName)
+ if err != nil {
+ return nil, err
+ }
+
+ cachedIndex = index
+ if cachedIndex != nil {
+ if time.Since(retrieved) < time.Hour*2 {
+ return cachedIndex, nil
+ }
+
+ cachedIndexRetrieved = &retrieved
+ }
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/index.json", hs.url), nil)
+ if err != nil {
+ return nil, err
+ }
+ if cachedIndexRetrieved != nil {
+ req.Header.Add("If-Modified-Since", cachedIndexRetrieved.Format(http.TimeFormat))
+ }
+ resp, err := hs.c.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if cachedIndexRetrieved != nil && resp.StatusCode == http.StatusNotModified {
+ // If status has not been modified, this is equivalent to returning the
+ // same index. We update the timestamp so the next cache index read does
+ // not require a roundtrip to the server.
+ if err = hs.cache.WriteIndex(hs.dbName, cachedIndex, time.Now()); err != nil {
+ return nil, err
+ }
+ return cachedIndex, nil
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+ b, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+ var index DBIndex
+ if err = json.Unmarshal(b, &index); err != nil {
+ return nil, err
+ }
+
+ if hs.cache != nil {
+ if err = hs.cache.WriteIndex(hs.dbName, index, time.Now()); err != nil {
+ return nil, err
+ }
+ }
+
+ return index, nil
+}
+
+func (*httpSource) unexported() {}
+
+func (hs *httpSource) GetByModule(ctx context.Context, modulePath string) (_ []*osv.Entry, err error) {
+ defer derrors.Wrap(&err, "httpSource.GetByModule(%q)", modulePath)
+
+ index, err := hs.Index(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ lastModified, present := index[modulePath]
+ if !present {
+ return nil, nil
+ }
+
+ if hs.cache != nil {
+ cached, err := hs.cache.ReadEntries(hs.dbName, modulePath)
+ if err != nil {
+ return nil, err
+ }
+ if len(cached) > 0 && !latestModifiedTime(cached).Before(lastModified) {
+ return cached, nil
+ }
+ }
+
+ epath, err := EscapeModulePath(modulePath)
+ if err != nil {
+ return nil, err
+ }
+ hs.httpCalls++ // for testing privacy properties
+ entries, err := httpReadJSON[[]*osv.Entry](ctx, hs, epath+".json")
+ if err != nil || entries == nil {
+ return nil, err
+ }
+ // TODO: we may want to check that the returned entries actually match
+ // the module we asked about, so that the cache cannot be poisoned
+ if hs.cache != nil {
+ if err := hs.cache.WriteEntries(hs.dbName, modulePath, entries); err != nil {
+ return nil, err
+ }
+ }
+ return entries, nil
+}
+
+func (hs *httpSource) GetByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
+ defer derrors.Wrap(&err, "GetByID(%q)", id)
+
+ return httpReadJSON[*osv.Entry](ctx, hs, fmt.Sprintf("%s/%s.json", internal.IDDirectory, id))
+}
+
+func (hs *httpSource) GetByAlias(ctx context.Context, alias string) (entries []*osv.Entry, err error) {
+ defer derrors.Wrap(&err, "httpSource.GetByAlias(%q)", alias)
+
+ aliasToIDs, err := httpReadJSON[map[string][]string](ctx, hs, "aliases.json")
+ if err != nil {
+ return nil, err
+ }
+ ids := aliasToIDs[alias]
+ if len(ids) == 0 {
+ return nil, nil
+ }
+ return getByIDs(ctx, hs, ids)
+}
+
+func (hs *httpSource) ListIDs(ctx context.Context) (_ []string, err error) {
+ defer derrors.Wrap(&err, "ListIDs()")
+
+ return httpReadJSON[[]string](ctx, hs, path.Join(internal.IDDirectory, "index.json"))
+}
+
+func httpReadJSON[T any](ctx context.Context, hs *httpSource, relativePath string) (T, error) {
+ var zero T
+ content, err := hs.readBody(ctx, fmt.Sprintf("%s/%s", hs.url, relativePath))
+ if err != nil {
+ return zero, err
+ }
+ if len(content) == 0 {
+ return zero, nil
+ }
+ var t T
+ if err := json.Unmarshal(content, &t); err != nil {
+ return zero, err
+ }
+ return t, nil
+}
+
+// This is the format for the last-modified header, as described at
+// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified.
+var lastModifiedFormat = "Mon, 2 Jan 2006 15:04:05 GMT"
+
+func (hs *httpSource) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
+ defer derrors.Wrap(&err, "LastModifiedTime()")
+
+ // Assume that if anything changes, the index does.
+ url := fmt.Sprintf("%s/index.json", hs.url)
+ req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
+ if err != nil {
+ return time.Time{}, err
+ }
+ resp, err := hs.c.Do(req)
+ if err != nil {
+ return time.Time{}, err
+ }
+ if resp.StatusCode != 200 {
+ return time.Time{}, fmt.Errorf("got status code %d", resp.StatusCode)
+ }
+ h := resp.Header.Get("Last-Modified")
+ return time.Parse(lastModifiedFormat, h)
+}
+
+func (hs *httpSource) readBody(ctx context.Context, url string) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := hs.c.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return nil, nil
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("got HTTP status %s", resp.Status)
+ }
+ // might want this to be a LimitedReader
+ return io.ReadAll(resp.Body)
+}
+
+type client struct {
+ sources []Client
+}
+
+type Options struct {
+ HTTPClient *http.Client
+ HTTPCache Cache
+}