internal/worker: read govulncheck work version per module@version

Instead of reading the govulncheck work versions for all modules in bulk,
which can lead to BigQuery quota issues upon container start-up, we read
the work version for a single module@version on each scan request.

Change-Id: Id90336ec196b5cc4a394bb113af15b27dc2343db
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/502135
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Maceo Thompson <maceothompson@google.com>
Run-TryBot: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/internal/govulncheck/govulncheck.go b/internal/govulncheck/govulncheck.go
index 5c45547..01914cd 100644
--- a/internal/govulncheck/govulncheck.go
+++ b/internal/govulncheck/govulncheck.go
@@ -205,39 +205,24 @@
 	ErrorCategory string
 }
 
-// ReadWorkStates reads the most recent work versions in the govulncheck table
-// together with their accompanying error categories.
-func ReadWorkStates(ctx context.Context, c *bigquery.Client) (_ map[[2]string]*WorkState, err error) {
-	defer derrors.Wrap(&err, "ReadWorkStates")
+// ReadWorkState reads the most recent work version for module_path@version
+// in the govulncheck table together with its accompanying error category.
+func ReadWorkState(ctx context.Context, c *bigquery.Client, module_path, version string) (ws *WorkState, err error) {
+	defer derrors.Wrap(&err, "ReadWorkState")
 
-	// Preamble defines an auxiliary table that remembers the
-	// latest version, defined by sort_version, for each module.
-	const preamble = "WITH latest AS (SELECT module_path AS module, MAX(sort_version) as max_version FROM `%s` GROUP BY module_path)"
-	latest := fmt.Sprintf(preamble, c.FullTableName(TableName))
-	// Partition the table by module and version while only
-	// considering the `latest` version. This is accomplished
-	// by joining govulncheck table with latest.
-	partition := bigquery.PartitionQuery{
-		From:        fmt.Sprintf("`%s` JOIN latest ON module_path=module AND sort_version=max_version", c.FullTableName(TableName)),
-		Columns:     "module_path, version, go_version, worker_version, schema_version, vulndb_last_modified, error_category",
-		PartitionOn: "module_path, sort_version",
-		OrderBy:     "created_at DESC",
-	}.String()
-	// Create the final query that gets only one work version
-	// for each module. The returned work version is the latest
-	// one, defined by max of sort_version. Note that this will
-	// not match the latest version of a module, in the strict Go
-	// sense, if the module has non-linear tagging (which should
-	// not happen too often).
-	query := fmt.Sprintf("%s\n%s", latest, partition)
+	const qf = `
+                SELECT module_path, version, go_version, worker_version, schema_version, vulndb_last_modified, error_category
+                FROM %s WHERE module_path="%s" AND version="%s" ORDER BY created_at DESC LIMIT 1
+        `
+	query := fmt.Sprintf(qf, "`"+c.FullTableName(TableName)+"`", module_path, version)
 	iter, err := c.Query(ctx, query)
 	if err != nil {
 		return nil, err
 	}
 
-	m := map[[2]string]*WorkState{}
 	err = bigquery.ForEachRow(iter, func(r *Result) bool {
-		m[[2]string{r.ModulePath, r.Version}] = &WorkState{
+		// The query uses LIMIT 1, so this callback runs at most once.
+		ws = &WorkState{
 			WorkVersion:   &r.WorkVersion,
 			ErrorCategory: r.ErrorCategory,
 		}
@@ -246,7 +231,7 @@
 	if err != nil {
 		return nil, err
 	}
-	return m, nil
+	return ws, nil
 }
 
 // ScanStats contains monitoring information for a govulncheck run.
diff --git a/internal/govulncheck/govulncheck_test.go b/internal/govulncheck/govulncheck_test.go
index b9afc83..7b767ca 100644
--- a/internal/govulncheck/govulncheck_test.go
+++ b/internal/govulncheck/govulncheck_test.go
@@ -171,28 +171,24 @@
 		}
 	})
 	t.Run("work versions", func(t *testing.T) {
-		wss, err := ReadWorkStates(ctx, client)
+		ws, err := ReadWorkState(ctx, client, "m", "v")
 		if err != nil {
 			t.Fatal(err)
 		}
-		wsgot := wss[[2]string{"m", "v"}]
-		if wsgot == nil {
+		if ws == nil {
 			t.Fatal("got nil, wanted work state")
 		}
-		wgot := wsgot.WorkVersion
+		wgot := ws.WorkVersion
 		if wgot == nil {
 			t.Fatal("got nil, wanted work version")
 		}
 		if want := &row.WorkVersion; !wgot.Equal(want) {
 			t.Errorf("got %+v, want %+v", wgot, want)
 		}
-		egot := wsgot.ErrorCategory
+		egot := ws.ErrorCategory
 		if want := row.ErrorCategory; want != egot {
 			t.Errorf("got %+v, want %+v", egot, want)
 		}
-		if got := wss[[2]string{"m", "v2"}]; got != nil {
-			t.Errorf("got %v; want nil", got)
-		}
 	})
 }
 
diff --git a/internal/worker/govulncheck.go b/internal/worker/govulncheck.go
index 9d9d066..3f4edc9 100644
--- a/internal/worker/govulncheck.go
+++ b/internal/worker/govulncheck.go
@@ -7,7 +7,6 @@
 import (
 	"context"
 	"encoding/json"
-	"errors"
 	"os"
 	"path/filepath"
 	"time"
@@ -16,7 +15,6 @@
 	"golang.org/x/pkgsite-metrics/internal/derrors"
 	"golang.org/x/pkgsite-metrics/internal/govulncheck"
 	"golang.org/x/pkgsite-metrics/internal/log"
-	"google.golang.org/api/googleapi"
 )
 
 type GovulncheckServer struct {
@@ -25,37 +23,11 @@
 	workVersion      *govulncheck.WorkVersion
 }
 
-func newGovulncheckServer(ctx context.Context, s *Server) (*GovulncheckServer, error) {
-	var (
-		swv map[[2]string]*govulncheck.WorkState
-		err error
-	)
-	if s.bqClient != nil {
-		swv, err = govulncheck.ReadWorkStates(ctx, s.bqClient)
-		if err != nil {
-			if isReadWorkStatesQuotaError(err) {
-				log.Info(ctx, "hit bigquery list quota when reading work versions, sleeping 1 minute...")
-				// Sleep a minute to allow quota limitations
-				// to clear up.
-				time.Sleep(60 * time.Second)
-			}
-			return nil, err
-		}
-		log.Infof(ctx, "read %d work versions", len(swv))
-	}
+func newGovulncheckServer(s *Server) *GovulncheckServer {
 	return &GovulncheckServer{
 		Server:           s,
-		storedWorkStates: swv,
-	}, nil
-}
-
-func isReadWorkStatesQuotaError(err error) bool {
-	var gerr *googleapi.Error
-	if !errors.As(err, &gerr) {
-		return false
+		storedWorkStates: make(map[[2]string]*govulncheck.WorkState),
 	}
-	// BigQuery uses 403 for quota exceeded.
-	return gerr.Code == 403
 }
 
 func (h *GovulncheckServer) getWorkVersion(ctx context.Context) (_ *govulncheck.WorkVersion, err error) {
diff --git a/internal/worker/govulncheck_scan.go b/internal/worker/govulncheck_scan.go
index 1327c40..38e5fee 100644
--- a/internal/worker/govulncheck_scan.go
+++ b/internal/worker/govulncheck_scan.go
@@ -13,6 +13,7 @@
 	"path/filepath"
 	"regexp"
 	"strings"
+	"time"
 
 	"cloud.google.com/go/storage"
 	"golang.org/x/exp/event"
@@ -24,6 +25,7 @@
 	"golang.org/x/pkgsite-metrics/internal/proxy"
 	"golang.org/x/pkgsite-metrics/internal/sandbox"
 	"golang.org/x/pkgsite-metrics/internal/version"
+	"google.golang.org/api/googleapi"
 )
 
 const (
@@ -97,7 +99,7 @@
 }
 
 func (h *GovulncheckServer) canSkip(ctx context.Context, sreq *govulncheck.Request, scanner *scanner) (bool, error) {
-	if err := h.readGovulncheckWorkStates(ctx); err != nil {
+	if err := h.readGovulncheckWorkState(ctx, sreq.Module, sreq.Version); err != nil {
 		return false, err
 	}
 	wve := h.storedWorkStates[[2]string{sreq.Module, sreq.Version}]
@@ -110,9 +112,8 @@
 		// If the work version has not changed, skip analyzing the module
 		return true, nil
 	}
-	// Otherwise, skip if the error is not recoverable
-	// TODO: should we perhaps do this at enqueueall point
-	// as well? Would that introduce more savings?
+	// Otherwise, skip if the error is not recoverable. The version of the
+	// module has not changed, so we'll get the same error anyhow.
 	return unrecoverableError(wve.ErrorCategory), nil
 }
 
@@ -129,18 +130,39 @@
 	}
 }
 
-func (h *GovulncheckServer) readGovulncheckWorkStates(ctx context.Context) error {
+func (h *GovulncheckServer) readGovulncheckWorkState(ctx context.Context, module_path, version string) error {
 	h.mu.Lock()
 	defer h.mu.Unlock()
-	if h.storedWorkStates != nil {
+	// Don't read work state for module_path@version if an entry in the cache already exists.
+	if _, ok := h.storedWorkStates[[2]string{module_path, version}]; ok {
 		return nil
 	}
 	if h.bqClient == nil {
 		return nil
 	}
-	var err error
-	h.storedWorkStates, err = govulncheck.ReadWorkStates(ctx, h.bqClient)
-	return err
+	ws, err := govulncheck.ReadWorkState(ctx, h.bqClient, module_path, version)
+	if err != nil {
+		if isReadWorkStatesQuotaError(err) {
+			log.Info(ctx, "hit bigquery list quota when reading work version, sleeping 1 minute...")
+			// Sleep a minute to allow quota limitations to clear up.
+			time.Sleep(60 * time.Second)
+		}
+		return err
+	}
+	if ws != nil {
+		h.storedWorkStates[[2]string{module_path, version}] = ws
+	}
+	log.Infof(ctx, "read work version for %s@%s", module_path, version)
+	return nil
+}
+
+func isReadWorkStatesQuotaError(err error) bool {
+	var gerr *googleapi.Error
+	if !errors.As(err, &gerr) {
+		return false
+	}
+	// BigQuery uses 403 for quota exceeded.
+	return gerr.Code == 403
 }
 
 // A scanner holds state for scanning modules.
diff --git a/internal/worker/server.go b/internal/worker/server.go
index 33d842e..2de9762 100644
--- a/internal/worker/server.go
+++ b/internal/worker/server.go
@@ -115,9 +115,7 @@
 	if err := ensureTable(ctx, bq, govulncheck.TableName); err != nil {
 		return nil, err
 	}
-	if err := s.registerGovulncheckHandlers(ctx); err != nil {
-		return nil, err
-	}
+	s.registerGovulncheckHandlers()
 	if err := ensureTable(ctx, bq, analysis.TableName); err != nil {
 		return nil, err
 	}
@@ -183,16 +181,11 @@
 	http.Handle(pattern, s.observer.Observe(h))
 }
 
-func (s *Server) registerGovulncheckHandlers(ctx context.Context) error {
-	h, err := newGovulncheckServer(ctx, s)
-	if err != nil {
-		return err
-	}
-
+func (s *Server) registerGovulncheckHandlers() {
+	h := newGovulncheckServer(s)
 	s.handle("/govulncheck/enqueueall", h.handleEnqueueAll)
 	s.handle("/govulncheck/enqueue", h.handleEnqueue)
 	s.handle("/govulncheck/scan/", h.handleScan)
-	return nil
 }
 
 func (s *Server) registerAnalysisHandlers(ctx context.Context) error {