internal/worker: add force option to ScanModules

Normally, a module isn't scanned if there is a record of the scan
in the DB. Providing a "force" query parameter or command-line flag
overrides that.

Change-Id: I1d8836cd011060feb0ef2cf33e033abb2dbe9e67
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/393835
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/cmd/worker/main.go b/cmd/worker/main.go
index 198ff38..40c8658 100644
--- a/cmd/worker/main.go
+++ b/cmd/worker/main.go
@@ -34,7 +34,7 @@
 var (
 	// Flags only for the command-line tool.
 	localRepoPath = flag.String("local-cve-repo", "", "path to local repo, instead of cloning remote")
-	force         = flag.Bool("force", false, "force an update to happen")
+	force         = flag.Bool("force", false, "force an update or scan to happen")
 	limit         = flag.Int("limit", 0,
 		"limit on number of things to list or issues to create (0 means unlimited)")
 	githubTokenFile = flag.String("ghtokenfile", "",
@@ -281,7 +281,7 @@
 }
 
 func scanModulesCommand(ctx context.Context) error {
-	return worker.ScanModules(ctx, cfg.Store)
+	return worker.ScanModules(ctx, cfg.Store, *force)
 }
 
 func die(format string, args ...interface{}) {
diff --git a/internal/worker/scan_modules.go b/internal/worker/scan_modules.go
index e3ed344..79ca291 100644
--- a/internal/worker/scan_modules.go
+++ b/internal/worker/scan_modules.go
@@ -38,7 +38,7 @@
 
 // ScanModules scans a list of Go modules for vulnerabilities.
 // It assumes the root of each repo is a module, and there are no nested modules.
-func ScanModules(ctx context.Context, st store.Store) error {
+func ScanModules(ctx context.Context, st store.Store, force bool) error {
 	dbClient, err := vulnc.NewClient([]string{vulnDBURL}, vulnc.Options{})
 	if err != nil {
 		return err
@@ -49,7 +49,7 @@
 		if err != nil {
 			return err
 		}
-		if err := processModule(ctx, modulePath, latest, dbClient, st); err != nil {
+		if err := processModule(ctx, modulePath, latest, dbClient, st, force); err != nil {
 			return err
 		}
 		latestTagged, err := latestTaggedVersion(ctx, modulePath)
@@ -57,7 +57,7 @@
 			return err
 		}
 		if latestTagged != "" && latestTagged != latest {
-			if err := processModule(ctx, modulePath, latestTagged, dbClient, st); err != nil {
+			if err := processModule(ctx, modulePath, latestTagged, dbClient, st, force); err != nil {
 				return err
 			}
 		}
@@ -65,23 +65,24 @@
 	return nil
 }
 
-func processModule(ctx context.Context, modulePath, version string, dbClient vulnc.Client, st store.Store) (err error) {
+func processModule(ctx context.Context, modulePath, version string, dbClient vulnc.Client, st store.Store, force bool) (err error) {
 	defer derrors.Wrap(&err, "processModule(%q, %q)", modulePath, version)
 
 	dbTime, err := vulnDBTime(ctx)
 	if err != nil {
 		return err
 	}
-	r, err := st.GetModuleScanRecord(ctx, modulePath, version, dbTime)
-	if err != nil {
-		return err
+	if !force {
+		r, err := st.GetModuleScanRecord(ctx, modulePath, version, dbTime)
+		if err != nil {
+			return err
+		}
+		if r != nil {
+			// Already done.
+			log.Debugf(ctx, "already scanned %s@%s at DB time %s", modulePath, version, dbTime)
+			return nil
+		}
 	}
-	if r != nil {
-		// Already done.
-		log.Debugf(ctx, "already scanned %s@%s at DB time %s", modulePath, version, dbTime)
-		return nil
-	}
-
 	res, err := scanModule(ctx, modulePath, version, dbClient)
 	if err2 := createModuleScanRecord(ctx, st, modulePath, version, dbTime, res, err); err2 != nil {
 		return err2
diff --git a/internal/worker/scan_modules_test.go b/internal/worker/scan_modules_test.go
index fdeb4ae..82296c7 100644
--- a/internal/worker/scan_modules_test.go
+++ b/internal/worker/scan_modules_test.go
@@ -26,7 +26,7 @@
 	// Verify only that scanModules works (doesn't return an error).
 	ctx := event.WithExporter(context.Background(),
 		event.NewExporter(log.NewLineHandler(os.Stderr), nil))
-	if err := ScanModules(ctx, store.NewMemStore()); err != nil {
+	if err := ScanModules(ctx, store.NewMemStore(), true); err != nil {
 		t.Fatal(err)
 	}
 }
diff --git a/internal/worker/server.go b/internal/worker/server.go
index 5370300..eda39c2 100644
--- a/internal/worker/server.go
+++ b/internal/worker/server.go
@@ -300,10 +300,7 @@
 			err:    fmt.Errorf("%s required", http.MethodPost),
 		}
 	}
-	force := false
-	if f := r.FormValue("force"); f == "true" {
-		force = true
-	}
+	force := (r.FormValue("force") == "true")
 	err = UpdateCVEsAtCommit(r.Context(), cvelistrepo.URL, "HEAD", s.cfg.Store, pkgsiteURL, force)
 	if cerr := new(CheckUpdateError); errors.As(err, &cerr) {
 		return &serverError{
@@ -375,7 +372,7 @@
 }
 
 func (s *Server) handleScanModules(w http.ResponseWriter, r *http.Request) error {
-	return ScanModules(r.Context(), s.cfg.Store)
+	return ScanModules(r.Context(), s.cfg.Store, r.FormValue("force") == "true")
 }
 
 func initOpenTelemetry(projectID string) (tp *sdktrace.TracerProvider, mp metric.MeterProvider, err error) {