client: add LastModifiedTime

Add a LastModifiedTime method to Client so that tools that poll the DB
can easily determine whether it has changed.

Fixes golang/go#51735

Change-Id: Idb07ed924ec256234f9eba7467429ab2fd1f2430
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/394874
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/client/client.go b/client/client.go
index 6b3ef4b..592c9d3 100644
--- a/client/client.go
+++ b/client/client.go
@@ -67,6 +67,11 @@
 	// ListIDs returns the IDs of all entries in the database.
 	ListIDs(context.Context) ([]string, error)
 
+	// LastModifiedTime returns the time that the database was last modified.
+	// It can be used by tools that periodically check for vulnerabilities
+	// to avoid repeating work.
+	LastModifiedTime(context.Context) (time.Time, error)
+
 	unexported() // ensures that adding a method won't break users
 }
 
@@ -124,6 +129,17 @@
 	return ids, nil
 }
 
+func (ls *localSource) LastModifiedTime(context.Context) (_ time.Time, err error) {
+	defer derrors.Wrap(&err, "LastModifiedTime()")
+
+	// Assume that if anything changes, the index does.
+	info, err := os.Stat(filepath.Join(ls.dir, "index.json"))
+	if err != nil {
+		return time.Time{}, err
+	}
+	return info.ModTime(), nil
+}
+
 func (ls *localSource) Index(context.Context) (_ DBIndex, err error) {
 	defer derrors.Wrap(&err, "Index()")
 	var index DBIndex
@@ -287,6 +303,30 @@
 	return ids, nil
 }
 
+// This is the format for the last-modified header, as described at
+// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified.
+var lastModifiedFormat = "Mon, 2 Jan 2006 15:04:05 GMT"
+
+func (hs *httpSource) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
+	defer derrors.Wrap(&err, "LastModifiedTime()")
+
+	// Assume that if anything changes, the index does.
+	url := fmt.Sprintf("%s/index.json", hs.url)
+	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
+	if err != nil {
+		return time.Time{}, err
+	}
+	resp, err := hs.c.Do(req)
+	if err != nil {
+		return time.Time{}, err
+	}
+	if resp.StatusCode != 200 {
+		return time.Time{}, fmt.Errorf("got status code %d", resp.StatusCode)
+	}
+	h := resp.Header.Get("Last-Modified")
+	return time.Parse(lastModifiedFormat, h)
+}
+
 func (hs *httpSource) readBody(ctx context.Context, url string) ([]byte, error) {
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
 	if err != nil {
@@ -396,3 +436,19 @@
 	sort.Strings(ids)
 	return ids, nil
 }
+
+// LastModifiedTime returns the latest modified time of all the sources.
+func (c *client) LastModifiedTime(ctx context.Context) (_ time.Time, err error) {
+	defer derrors.Wrap(&err, "LastModifiedTime()")
+	var lmt time.Time
+	for _, s := range c.sources {
+		t, err := s.LastModifiedTime(ctx)
+		if err != nil {
+			return time.Time{}, err
+		}
+		if t.After(lmt) {
+			lmt = t
+		}
+	}
+	return lmt, nil
+}
diff --git a/client/client_test.go b/client/client_test.go
index cb45cf1..e7bc1d1 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -140,6 +140,7 @@
 func newTestServer() *httptest.Server {
 	dataHandler := func(data string) http.HandlerFunc {
 		return func(w http.ResponseWriter, _ *http.Request) {
+			w.Header().Set("Last-Modified", time.Now().In(time.UTC).Format(lastModifiedFormat))
 			io.WriteString(w, data)
 		}
 	}
@@ -359,3 +360,43 @@
 		})
 	}
 }
+
+func TestLastModifiedTime(t *testing.T) {
+	if runtime.GOOS == "js" {
+		t.Skip("skipping test: no network on js")
+	}
+
+	srv := newTestServer()
+	defer srv.Close()
+
+	// Create a local file database.
+	localDBName, err := localDB(t)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(localDBName)
+
+	for _, test := range []struct {
+		name   string
+		source string
+	}{
+		{name: "http", source: srv.URL},
+		{name: "file", source: "file://" + localDBName},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			client, err := NewClient([]string{test.source}, Options{})
+			if err != nil {
+				t.Fatal(err)
+			}
+			got, err := client.LastModifiedTime(context.Background())
+			if err != nil {
+				t.Fatal(err)
+			}
+			// Time should be a little before now.
+			diff := time.Since(got)
+			if diff < 0 || diff > 10*time.Second {
+				t.Errorf("got difference from now of %s, wanted something positive under 10s", diff)
+			}
+		})
+	}
+}