client: add LastModifiedTime
Add a LastModifiedTime method to Client so that tools that poll the DB
can easily determine whether it has changed.
Fixes golang/go#51735
Change-Id: Idb07ed924ec256234f9eba7467429ab2fd1f2430
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/394874
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/client/client.go b/client/client.go
index 6b3ef4b..592c9d3 100644
--- a/client/client.go
+++ b/client/client.go
@@ -67,6 +67,11 @@
// ListIDs returns the IDs of all entries in the database.
ListIDs(context.Context) ([]string, error)
+ // LastModifiedTime returns the time that the database was last modified.
+ // It can be used by tools that periodically check for vulnerabilities
+ // to avoid repeating work.
+ LastModifiedTime(context.Context) (time.Time, error)
+
unexported() // ensures that adding a method won't break users
}
@@ -124,6 +129,17 @@
return ids, nil
}
+func (ls *localSource) LastModifiedTime(context.Context) (_ time.Time, err error) {
+ defer derrors.Wrap(&err, "LastModifiedTime()")
+
+ // Assume that if anything changes, the index does.
+ info, err := os.Stat(filepath.Join(ls.dir, "index.json"))
+ if err != nil {
+ return time.Time{}, err
+ }
+ return info.ModTime(), nil
+}
+
func (ls *localSource) Index(context.Context) (_ DBIndex, err error) {
defer derrors.Wrap(&err, "Index()")
var index DBIndex
@@ -287,6 +303,30 @@
return ids, nil
}
+// This is the format for the last-modified header, as described at
+// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified.
+var lastModifiedFormat = "Mon, 2 Jan 2006 15:04:05 GMT"
+
+func (hs *httpSource) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
+ defer derrors.Wrap(&err, "LastModifiedTime()")
+
+ // Assume that if anything changes, the index does.
+ url := fmt.Sprintf("%s/index.json", hs.url)
+ req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
+ if err != nil {
+ return time.Time{}, err
+ }
+ resp, err := hs.c.Do(req)
+ if err != nil {
+ return time.Time{}, err
+ }
+ if resp.StatusCode != 200 {
+ return time.Time{}, fmt.Errorf("got status code %d", resp.StatusCode)
+ }
+ h := resp.Header.Get("Last-Modified")
+ return time.Parse(lastModifiedFormat, h)
+}
+
func (hs *httpSource) readBody(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
@@ -396,3 +436,19 @@
sort.Strings(ids)
return ids, nil
}
+
+// LastModifiedTime returns the latest modified time of all the sources.
+func (c *client) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
+ defer derrors.Wrap(&err, "LastModifiedTime()")
+ var lmt time.Time
+ for _, s := range c.sources {
+ t, err := s.LastModifiedTime(ctx)
+ if err != nil {
+ return time.Time{}, err
+ }
+ if t.After(lmt) {
+ lmt = t
+ }
+ }
+ return lmt, nil
+}
diff --git a/client/client_test.go b/client/client_test.go
index cb45cf1..e7bc1d1 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -140,6 +140,7 @@
func newTestServer() *httptest.Server {
dataHandler := func(data string) http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Last-Modified", time.Now().In(time.UTC).Format(lastModifiedFormat))
io.WriteString(w, data)
}
}
@@ -359,3 +360,43 @@
})
}
}
+
+func TestLastModifiedTime(t *testing.T) {
+ if runtime.GOOS == "js" {
+ t.Skip("skipping test: no network on js")
+ }
+
+ srv := newTestServer()
+ defer srv.Close()
+
+ // Create a local file database.
+ localDBName, err := localDB(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(localDBName)
+
+ for _, test := range []struct {
+ name string
+ source string
+ }{
+ {name: "http", source: srv.URL},
+ {name: "file", source: "file://" + localDBName},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ client, err := NewClient([]string{test.source}, Options{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := client.LastModifiedTime(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Time should be a little before now.
+ diff := time.Since(got)
+ if diff < 0 || diff > 10*time.Second {
+ t.Errorf("got difference from now of %s, wanted something positive under 10s", diff)
+ }
+ })
+ }
+}