internal/{govulncheck,worker}: use Firestore for WorkStates

Store and retrieve work states from Firestore.

This should be significantly faster than reading them from BigQuery,
so remove the per-process cache.

Change-Id: I76c3ca2aa28e9aa7a222cee58b7f43cf9d7386a8
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/551756
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/internal/fstore/fstore.go b/internal/fstore/fstore.go
index 1c779e4..0708fc6 100644
--- a/internal/fstore/fstore.go
+++ b/internal/fstore/fstore.go
@@ -13,6 +13,8 @@
 
 	"cloud.google.com/go/firestore"
 	"golang.org/x/pkgsite-metrics/internal/derrors"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 const namespaceCollection = "Namespaces"
@@ -62,16 +64,40 @@
 	defer derrors.Wrap(&err, "fstore.Get(%q)", dr.Path)
 	docsnap, err := dr.Get(ctx)
 	if err != nil {
-		return nil, err
+		return nil, convertError(err)
 	}
 	return Decode[T](docsnap)
 }
 
+// Set writes value to the document referred to by dr; errors are mapped by convertError and wrapped with dr.Path.
+func Set[T any](ctx context.Context, dr *firestore.DocumentRef, value *T) (err error) {
+	defer derrors.Wrap(&err, "fstore.Set(%q)", dr.Path)
+	_, err = dr.Set(ctx, value)
+	return convertError(err)
+}
+
 // Decode decodes a DocumentSnapshot into a value of type T.
 func Decode[T any](ds *firestore.DocumentSnapshot) (*T, error) {
 	var t T
 	if err := ds.DataTo(&t); err != nil {
-		return nil, err
+		return nil, convertError(err)
 	}
 	return &t, nil
 }
+
+// convertError maps a gRPC status error onto the matching derrors kind.
+// NotFound and InvalidArgument codes are translated; an error that
+// carries any other code, or that status.FromError does not recognize,
+// is handed back unchanged.
+func convertError(err error) error {
+	st, isStatus := status.FromError(err)
+	if isStatus {
+		switch st.Code() {
+		case codes.NotFound:
+			return derrors.NotFound
+		case codes.InvalidArgument:
+			return derrors.InvalidArgument
+		}
+	}
+	return err
+}
diff --git a/internal/govulncheck/govulncheck.go b/internal/govulncheck/govulncheck.go
index b2c005b..fdb346a 100644
--- a/internal/govulncheck/govulncheck.go
+++ b/internal/govulncheck/govulncheck.go
@@ -11,8 +11,8 @@
 	"context"
 	"encoding/json"
 	"errors"
-	"fmt"
 	"net/http"
+	"net/url"
 	"os/exec"
 	"path/filepath"
 	"runtime"
@@ -23,7 +23,9 @@
 
 	"golang.org/x/pkgsite-metrics/internal/bigquery"
 	"golang.org/x/pkgsite-metrics/internal/derrors"
+	"golang.org/x/pkgsite-metrics/internal/fstore"
 	"golang.org/x/pkgsite-metrics/internal/govulncheckapi"
+	"golang.org/x/pkgsite-metrics/internal/log"
 	"golang.org/x/pkgsite-metrics/internal/scan"
 )
 
@@ -149,6 +151,14 @@
 	Vulns              []*Vuln        `bigquery:"vulns"`
 }
 
+// WorkState builds a WorkState from r; its WorkVersion points at r's own field, not a copy.
+func (r *Result) WorkState() *WorkState {
+	ws := new(WorkState)
+	ws.WorkVersion = &r.WorkVersion
+	ws.ErrorCategory = r.ErrorCategory
+	return ws
+}
+
 // WorkVersion contains information that can be used to avoid duplicate work.
 // Given two WorkVersion values v1 and v2 for the same module path and version,
 // if v1.Equal(v2) then it is not necessary to scan the module.
@@ -216,35 +226,6 @@
 	ErrorCategory string
 }
 
-// ReadWorkState reads the most recent work version for module_path@version
-// in the govulncheck table together with its accompanying error category.
-func ReadWorkState(ctx context.Context, c *bigquery.Client, module_path, version string) (ws *WorkState, err error) {
-	defer derrors.Wrap(&err, "ReadWorkState")
-
-	const qf = `
-                SELECT go_version, worker_version, schema_version, vulndb_last_modified, error_category
-                FROM %s WHERE module_path="%s" AND version="%s" ORDER BY created_at DESC LIMIT 1
-        `
-	query := fmt.Sprintf(qf, "`"+c.FullTableName(TableName)+"`", module_path, version)
-	iter, err := c.Query(ctx, query)
-	if err != nil {
-		return nil, err
-	}
-
-	err = bigquery.ForEachRow(iter, func(r *Result) bool {
-		// This should be reachable at most once.
-		ws = &WorkState{
-			WorkVersion:   &r.WorkVersion,
-			ErrorCategory: r.ErrorCategory,
-		}
-		return true
-	})
-	if err != nil {
-		return nil, err
-	}
-	return ws, nil
-}
-
 // ScanStats contains monitoring information for a govulncheck run.
 type ScanStats struct {
 	// ScanSeconds is the amount of time a scan took to run, in seconds.
@@ -342,3 +323,36 @@
 var getMemoryUsage = func(c *exec.Cmd) uint64 {
 	return 0
 }
+
+const collName = "GovulncheckWorkStates"
+
+// SetWorkState writes ws as the work state for modulePath@version, logging the call and its result via log.Debugf.
+func SetWorkState(ctx context.Context, ns *fstore.Namespace, modulePath, version string, ws *WorkState) (err error) {
+	defer func() {
+		log.Debugf(ctx, "SetWorkState(%s@%s, %+v) => %v", modulePath, version, ws, err)
+	}()
+	dr := ns.Collection(collName).Doc(docName(modulePath, version))
+	return fstore.Set[WorkState](ctx, dr, ws)
+}
+
+// GetWorkState reads the work state for modulePath@version.
+// If there is none, it returns (nil, nil).
+func GetWorkState(ctx context.Context, ns *fstore.Namespace, modulePath, version string) (ws *WorkState, err error) {
+	defer func() {
+		log.Debugf(ctx, "GetWorkState(%s@%s) => (%+v, %v)", modulePath, version, ws, err)
+	}()
+	// Deferred calls run last-in-first-out, so the log above reports the wrapped error.
+	defer derrors.Wrap(&err, "GetWorkState(%q, %q)", modulePath, version)
+	dr := ns.Collection(collName).Doc(docName(modulePath, version))
+	ws, err = fstore.Get[WorkState](ctx, dr)
+	if errors.Is(err, derrors.NotFound) {
+		return nil, nil
+	}
+	return ws, err
+}
+
+// docName returns the Firestore document name used for modulePath@version.
+// It applies url.PathEscape, which escapes slashes, since Firestore treats them specially.
+func docName(modulePath, version string) string {
+	return url.PathEscape(modulePath + "@" + version)
+}
diff --git a/internal/govulncheck/govulncheck_test.go b/internal/govulncheck/govulncheck_test.go
index 00519d5..6c76fda 100644
--- a/internal/govulncheck/govulncheck_test.go
+++ b/internal/govulncheck/govulncheck_test.go
@@ -13,6 +13,7 @@
 	bq "cloud.google.com/go/bigquery"
 	"github.com/google/go-cmp/cmp"
 	"golang.org/x/pkgsite-metrics/internal/bigquery"
+	"golang.org/x/pkgsite-metrics/internal/fstore"
 	"golang.org/x/pkgsite-metrics/internal/govulncheckapi"
 	test "golang.org/x/pkgsite-metrics/internal/testing"
 	"google.golang.org/api/iterator"
@@ -114,12 +115,8 @@
 	defer func() { must(client.Table(TableName).Delete(ctx)) }()
 
 	tm := time.Date(2022, 7, 21, 0, 0, 0, 0, time.UTC)
-	row := &Result{
-		ModulePath:  "m",
-		Version:     "v",
-		SortVersion: "sv",
-		ImportedBy:  10,
-		WorkVersion: WorkVersion{
+	ws := &WorkState{
+		WorkVersion: &WorkVersion{
 			GoVersion:          "go1.19.6",
 			WorkerVersion:      "1",
 			SchemaVersion:      "s",
@@ -127,6 +124,14 @@
 		},
 		ErrorCategory: "SOME ERROR",
 	}
+	row := &Result{
+		ModulePath:    "m",
+		Version:       "v",
+		SortVersion:   "sv",
+		ImportedBy:    10,
+		WorkVersion:   *ws.WorkVersion,
+		ErrorCategory: ws.ErrorCategory,
+	}
 
 	t.Run("upload", func(t *testing.T) {
 		must(client.Upload(ctx, TableName, row))
@@ -147,24 +152,26 @@
 			t.Errorf("mismatch (-want, +got):\n%s", diff)
 		}
 	})
-	t.Run("work versions", func(t *testing.T) {
-		ws, err := ReadWorkState(ctx, client, "m", "v")
+	t.Run("work states", func(t *testing.T) {
+		ns, err := fstore.OpenNamespace(ctx, projectID, "testing")
 		if err != nil {
 			t.Fatal(err)
 		}
-		if ws == nil {
-			t.Fatal("got nil, wanted work state")
+		if err := SetWorkState(ctx, ns, "example.com/mod", "v1.0.0", ws); err != nil {
+			t.Fatal(err)
 		}
-		wgot := ws.WorkVersion
-		if wgot == nil {
-			t.Fatal("got nil, wanted work version")
+		got, err := GetWorkState(ctx, ns, "example.com/mod", "v1.0.0")
+		if err != nil {
+			t.Fatal(err)
 		}
-		if want := &row.WorkVersion; !wgot.Equal(want) {
-			t.Errorf("got %+v, want %+v", wgot, want)
+		if !cmp.Equal(got, ws) {
+			t.Errorf("got %+v\nwant %+v", got, ws)
 		}
-		egot := ws.ErrorCategory
-		if want := row.ErrorCategory; want != egot {
-			t.Errorf("got %+v, want %+v", egot, want)
+
+		// GetWorkState returns nil if the WorkState doesn't exist.
+		got, err = GetWorkState(ctx, ns, "example.com/mod", "v1.2.3")
+		if got != nil || err != nil {
+			t.Errorf("got (%v, %v), want (nil, nil)", got, err)
 		}
 	})
 }
diff --git a/internal/worker/govulncheck.go b/internal/worker/govulncheck.go
index 3f4edc9..c6bf92a 100644
--- a/internal/worker/govulncheck.go
+++ b/internal/worker/govulncheck.go
@@ -19,15 +19,11 @@
 
 type GovulncheckServer struct {
 	*Server
-	storedWorkStates map[[2]string]*govulncheck.WorkState
-	workVersion      *govulncheck.WorkVersion
+	workVersion *govulncheck.WorkVersion
 }
 
 func newGovulncheckServer(s *Server) *GovulncheckServer {
-	return &GovulncheckServer{
-		Server:           s,
-		storedWorkStates: make(map[[2]string]*govulncheck.WorkState),
-	}
+	return &GovulncheckServer{Server: s}
 }
 
 func (h *GovulncheckServer) getWorkVersion(ctx context.Context) (_ *govulncheck.WorkVersion, err error) {
diff --git a/internal/worker/govulncheck_scan.go b/internal/worker/govulncheck_scan.go
index c8e48f6..2f41a17 100644
--- a/internal/worker/govulncheck_scan.go
+++ b/internal/worker/govulncheck_scan.go
@@ -17,6 +17,7 @@
 	"golang.org/x/exp/event"
 	"golang.org/x/pkgsite-metrics/internal/bigquery"
 	"golang.org/x/pkgsite-metrics/internal/derrors"
+	"golang.org/x/pkgsite-metrics/internal/fstore"
 	"golang.org/x/pkgsite-metrics/internal/govulncheck"
 	"golang.org/x/pkgsite-metrics/internal/govulncheckapi"
 	"golang.org/x/pkgsite-metrics/internal/log"
@@ -105,7 +106,7 @@
 	if sreq.Insecure {
 		scanner.insecure = sreq.Insecure
 	}
-	skip, err = h.canSkip(ctx, sreq, scanner)
+	skip, err = scanner.canSkip(ctx, sreq, h.fsNamespace)
 	if err != nil {
 		return err
 	}
@@ -113,27 +114,40 @@
 		log.Infof(ctx, "skipping (work version unchanged or unrecoverable error): %s@%s", sreq.Module, sreq.Version)
 		return nil
 	}
-
-	return scanner.ScanModule(ctx, w, sreq)
+	workState, err := scanner.ScanModule(ctx, w, sreq)
+	if err != nil {
+		return err
+	}
+	if workState == nil {
+		return nil
+	}
+	// We can't upload the row to bigquery and write the WorkState to Firestore atomically.
+	// But that's OK: if we fail before writing the WorkState, then we'll just re-do the scan
+	// the next time.
+	if err := govulncheck.SetWorkState(ctx, h.fsNamespace, sreq.Module, sreq.Version, workState); err != nil {
+		// Don't fail if there's an error, because we'd just re-run the task.
+		log.Errorf(ctx, err, "SetWorkState")
+	}
+	return nil
 }
 
-func (h *GovulncheckServer) canSkip(ctx context.Context, sreq *govulncheck.Request, scanner *scanner) (bool, error) {
-	if err := h.readGovulncheckWorkState(ctx, sreq.Module, sreq.Version); err != nil {
+func (s *scanner) canSkip(ctx context.Context, sreq *govulncheck.Request, fsn *fstore.Namespace) (bool, error) {
+	ws, err := govulncheck.GetWorkState(ctx, fsn, sreq.Module, sreq.Version)
+	if err != nil {
 		return false, err
 	}
-	wve := h.storedWorkStates[[2]string{sreq.Module, sreq.Version}]
-	if wve == nil {
-		// sreq.Module@sreq.Version have not been analyzed before.
+	if ws == nil {
+		// Not scanned before.
 		return false, nil
 	}
-
-	if scanner.workVersion.Equal(wve.WorkVersion) {
+	log.Infof(ctx, "read work version for %s@%s", sreq.Module, sreq.Version)
+	if s.workVersion.Equal(ws.WorkVersion) {
 		// If the work version has not changed, skip analyzing the module
 		return true, nil
 	}
 	// Otherwise, skip if the error is not recoverable. The version of the
 	// module has not changed, so we'll get the same error anyhow.
-	return unrecoverableError(wve.ErrorCategory), nil
+	return unrecoverableError(ws.ErrorCategory), nil
 }
 
 // unrecoverableError returns true iff errorCategory encodes that
@@ -148,27 +162,6 @@
 	}
 }
 
-func (h *GovulncheckServer) readGovulncheckWorkState(ctx context.Context, module_path, version string) error {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	// Don't read work state for module_path@version if an entry in the cache already exists.
-	if _, ok := h.storedWorkStates[[2]string{module_path, version}]; ok {
-		return nil
-	}
-	if h.bqClient == nil {
-		return nil
-	}
-	ws, err := govulncheck.ReadWorkState(ctx, h.bqClient, module_path, version)
-	if err != nil {
-		return err
-	}
-	if ws != nil {
-		h.storedWorkStates[[2]string{module_path, version}] = ws
-	}
-	log.Infof(ctx, "read work version for %s@%s", module_path, version)
-	return nil
-}
-
 // A scanner holds state for scanning modules.
 type scanner struct {
 	proxyClient *proxy.Client
@@ -306,9 +299,10 @@
 	return row
 }
 
-func (s *scanner) ScanModule(ctx context.Context, w http.ResponseWriter, sreq *govulncheck.Request) error {
+// ScanModule scans the module in the request. It returns the WorkState for the result.
+func (s *scanner) ScanModule(ctx context.Context, w http.ResponseWriter, sreq *govulncheck.Request) (*govulncheck.WorkState, error) {
 	if sreq.Module == "std" {
-		return nil // ignore the standard library
+		return nil, nil // ignore the standard library
 	}
 	row := &govulncheck.Result{
 		ModulePath:  sreq.Module,
@@ -326,14 +320,19 @@
 		log.Infof(ctx, "proxy error: %s@%s %v", sreq.Path(), sreq.Version, err)
 		row.AddError(fmt.Errorf("%v: %w", err, derrors.ProxyError))
 		// TODO: should we also make a copy for imports mode?
-		return writeResult(ctx, sreq.Serve, w, s.bqClient, govulncheck.TableName, row)
+		if err := writeResult(ctx, sreq.Serve, w, s.bqClient, govulncheck.TableName, row); err != nil {
+			return nil, err
+		}
+		return row.WorkState(), nil
 	}
 	row.Version = info.Version
 	row.SortVersion = version.ForSorting(row.Version)
 	row.CommitTime = info.Time
 
 	if sreq.Mode == ModeCompare {
-		return s.CompareModule(ctx, w, sreq, info, row)
+		err := s.CompareModule(ctx, w, sreq, info, row)
+		// TODO: WorkState for CompareModule requests?
+		return nil, err
 	}
 
 	log.Infof(ctx, "running scanner.runScanModule: %s@%s", sreq.Path(), sreq.Version)
@@ -395,7 +394,10 @@
 		log.Infof(ctx, "scanner.runScanModule also storing imports vulns for %s: row.Vulns=%d", sreq.Path(), len(impRow.Vulns))
 		rows = append(rows, &impRow)
 	}
-	return writeResults(ctx, sreq.Serve, w, s.bqClient, govulncheck.TableName, rows)
+	if err := writeResults(ctx, sreq.Serve, w, s.bqClient, govulncheck.TableName, rows); err != nil {
+		return nil, err
+	}
+	return row.WorkState(), nil
 }
 
 // vulnsForMode returns vulns that make sense to report for
diff --git a/internal/worker/server.go b/internal/worker/server.go
index b398b12..ef695c2 100644
--- a/internal/worker/server.go
+++ b/internal/worker/server.go
@@ -20,6 +20,7 @@
 	"golang.org/x/pkgsite-metrics/internal/bigquery"
 	"golang.org/x/pkgsite-metrics/internal/config"
 	"golang.org/x/pkgsite-metrics/internal/derrors"
+	"golang.org/x/pkgsite-metrics/internal/fstore"
 	"golang.org/x/pkgsite-metrics/internal/govulncheck"
 	"golang.org/x/pkgsite-metrics/internal/jobs"
 	"golang.org/x/pkgsite-metrics/internal/log"
@@ -35,6 +36,8 @@
 	proxyClient *proxy.Client
 	queue       queue.Queue
 	jobDB       *jobs.DB
+	// Firestore namespace for storing work states.
+	fsNamespace *fstore.Namespace
 
 	reqs int // for debugging
 
@@ -60,6 +63,12 @@
 		}
 	}
 
+	// Use the same name for the namespace as the BQ dataset.
+	ns, err := fstore.OpenNamespace(ctx, cfg.ProjectID, cfg.BigQueryDataset)
+	if err != nil {
+		return nil, err
+	}
+
 	q, err := queue.New(ctx, cfg,
 		func(ctx context.Context, t queue.Task) (int, error) {
 			// When running locally, only the module path and version are
@@ -95,6 +104,7 @@
 		proxyClient: proxyClient,
 		devMode:     cfg.DevMode,
 		jobDB:       jdb,
+		fsNamespace: ns,
 	}
 
 	if cfg.ProjectID != "" && cfg.ServiceID != "" {