internal/worker: refactor vulncheck enqueue logic

Minor cleanup.

Change-Id: Id5ad8690effee63839639ffbaac674c2705b81d7
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/473175
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/internal/worker/vulncheck_enqueue.go b/internal/worker/vulncheck_enqueue.go
index 8302b98..5e9f0fd 100644
--- a/internal/worker/vulncheck_enqueue.go
+++ b/internal/worker/vulncheck_enqueue.go
@@ -41,52 +41,57 @@
 }
 
 func (h *VulncheckServer) enqueue(r *http.Request, allModes bool) error {
-	tasks, params, err := createVulncheckQueueTasks(r, h.cfg, allModes)
-	if err != nil {
-		return err
-	}
-	return enqueueTasks(r.Context(), tasks, h.queue,
-		&queue.Options{Namespace: "vulncheck", TaskNameSuffix: params.Suffix})
-}
-
-func createVulncheckQueueTasks(r *http.Request, cfg *config.Config, allModes bool) (_ []queue.Task, _ *vulncheckEnqueueParams, err error) {
-	defer derrors.Wrap(&err, "createQueueTasks(%s, %t)", r.URL, allModes)
 	ctx := r.Context()
 	params := &vulncheckEnqueueParams{Min: defaultMinImportedByCount}
 	if err := scan.ParseParams(r, params); err != nil {
-		return nil, nil, err
+		return fmt.Errorf("%w: %v", derrors.InvalidArgument, err)
 	}
-	if allModes && params.Mode != "" {
-		return nil, nil, errors.New("mode query param provided for enqueueAll")
+	modes, err := listModes(params.Mode, allModes)
+	if err != nil {
+		return fmt.Errorf("%w: %v", derrors.InvalidArgument, err)
 	}
-	var enqueueModes []string
-	if allModes {
-		enqueueModes = maps.Keys(modes)
-		sort.Strings(enqueueModes) // make deterministic for testing
-	} else {
-		mode, err := vulncheckMode(params.Mode)
-		if err != nil {
-			return nil, nil, err
-		}
-		enqueueModes = []string{mode}
+	tasks, err := createVulncheckQueueTasks(ctx, h.cfg, params, modes)
+	if err != nil {
+		return err
 	}
+	return enqueueTasks(ctx, tasks, h.queue,
+		&queue.Options{Namespace: "vulncheck", TaskNameSuffix: params.Suffix})
+}
 
+func listModes(modeParam string, allModes bool) ([]string, error) {
+	if allModes {
+		if modeParam != "" {
+			return nil, errors.New("mode query param provided for enqueueAll")
+		}
+		ms := maps.Keys(modes)
+		sort.Strings(ms) // make deterministic for testing
+		return ms, nil
+	}
+	mode, err := vulncheckMode(modeParam)
+	if err != nil {
+		return nil, err
+	}
+	return []string{mode}, nil
+}
+
+func createVulncheckQueueTasks(ctx context.Context, cfg *config.Config, params *vulncheckEnqueueParams, modes []string) (_ []queue.Task, err error) {
+	defer derrors.Wrap(&err, "createVulncheckQueueTasks(%v)", modes)
 	var (
 		tasks    []queue.Task
 		modspecs []scan.ModuleSpec
 	)
-	for _, mode := range enqueueModes {
+	for _, mode := range modes {
 		var reqs []*vulncheckRequest
 		if mode == ModeBinary {
 			reqs, err = readBinaries(ctx, cfg.BinaryBucket)
 			if err != nil {
-				return nil, nil, err
+				return nil, err
 			}
 		} else {
 			if modspecs == nil {
 				modspecs, err = readModules(ctx, cfg, params.File, params.Min)
 				if err != nil {
-					return nil, nil, err
+					return nil, err
 				}
 			}
 			reqs = moduleSpecsToScanRequests(modspecs, mode)
@@ -97,7 +102,7 @@
 			}
 		}
 	}
-	return tasks, params, nil
+	return tasks, nil
 }
 
 func vulncheckMode(mode string) (string, error) {
diff --git a/internal/worker/vulncheck_enqueue_test.go b/internal/worker/vulncheck_enqueue_test.go
index ad17ea3..9ecc0fe 100644
--- a/internal/worker/vulncheck_enqueue_test.go
+++ b/internal/worker/vulncheck_enqueue_test.go
@@ -7,7 +7,7 @@
 import (
 	"context"
 	"flag"
-	"net/http"
+	"fmt"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -57,11 +57,8 @@
 		}
 	}
 
-	req, err := http.NewRequest("GET", "https://path?min=8&file=testdata/modules.txt", nil)
-	if err != nil {
-		t.Fatal(err)
-	}
-	gotTasks, _, err := createVulncheckQueueTasks(req, &config.Config{}, false)
+	params := &vulncheckEnqueueParams{Min: 8, File: "testdata/modules.txt"}
+	gotTasks, err := createVulncheckQueueTasks(context.Background(), &config.Config{}, params, []string{ModeVTA})
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -74,7 +71,11 @@
 		t.Errorf("mismatch (-want, +got):\n%s", diff)
 	}
 
-	gotTasks, _, err = createVulncheckQueueTasks(req, &config.Config{}, true)
+	allModes, err := listModes("", true)
+	if err != nil {
+		t.Fatal(err)
+	}
+	gotTasks, err = createVulncheckQueueTasks(context.Background(), &config.Config{}, params, allModes)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -90,3 +91,27 @@
 		t.Errorf("mismatch (-want, +got):\n%s", diff)
 	}
 }
+
+func TestListModes(t *testing.T) {
+	for _, test := range []struct {
+		param   string
+		all     bool
+		want    []string
+		wantErr bool
+	}{
+		{"", true, []string{ModeBinary, ModeGovulncheck, ModeImports, ModeVTA, ModeVTAStacks}, false},
+		{"", false, []string{ModeVTA}, false},
+		{"imports", false, []string{ModeImports}, false},
+		{"imports", true, nil, true},
+	} {
+		t.Run(fmt.Sprintf("%q,%t", test.param, test.all), func(t *testing.T) {
+			got, err := listModes(test.param, test.all)
+			if err != nil && !test.wantErr {
+				t.Fatal(err)
+			}
+			if err == nil && !cmp.Equal(got, test.want) {
+				t.Errorf("got %v, want %v", got, test.want)
+			}
+		})
+	}
+}