internal/worker: refactor vulncheck enqueue logic
Minor cleanup.
Change-Id: Id5ad8690effee63839639ffbaac674c2705b81d7
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/473175
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/internal/worker/vulncheck_enqueue.go b/internal/worker/vulncheck_enqueue.go
index 8302b98..5e9f0fd 100644
--- a/internal/worker/vulncheck_enqueue.go
+++ b/internal/worker/vulncheck_enqueue.go
@@ -41,52 +41,57 @@
}
func (h *VulncheckServer) enqueue(r *http.Request, allModes bool) error {
- tasks, params, err := createVulncheckQueueTasks(r, h.cfg, allModes)
- if err != nil {
- return err
- }
- return enqueueTasks(r.Context(), tasks, h.queue,
- &queue.Options{Namespace: "vulncheck", TaskNameSuffix: params.Suffix})
-}
-
-func createVulncheckQueueTasks(r *http.Request, cfg *config.Config, allModes bool) (_ []queue.Task, _ *vulncheckEnqueueParams, err error) {
- defer derrors.Wrap(&err, "createQueueTasks(%s, %t)", r.URL, allModes)
ctx := r.Context()
params := &vulncheckEnqueueParams{Min: defaultMinImportedByCount}
if err := scan.ParseParams(r, params); err != nil {
- return nil, nil, err
+ return fmt.Errorf("%w: %v", derrors.InvalidArgument, err)
}
- if allModes && params.Mode != "" {
- return nil, nil, errors.New("mode query param provided for enqueueAll")
+ modes, err := listModes(params.Mode, allModes)
+ if err != nil {
+ return fmt.Errorf("%w: %v", derrors.InvalidArgument, err)
}
- var enqueueModes []string
- if allModes {
- enqueueModes = maps.Keys(modes)
- sort.Strings(enqueueModes) // make deterministic for testing
- } else {
- mode, err := vulncheckMode(params.Mode)
- if err != nil {
- return nil, nil, err
- }
- enqueueModes = []string{mode}
+ tasks, err := createVulncheckQueueTasks(ctx, h.cfg, params, modes)
+ if err != nil {
+ return err
}
+ return enqueueTasks(ctx, tasks, h.queue,
+ &queue.Options{Namespace: "vulncheck", TaskNameSuffix: params.Suffix})
+}
+func listModes(modeParam string, allModes bool) ([]string, error) {
+ if allModes {
+ if modeParam != "" {
+ return nil, errors.New("mode query param provided for enqueueAll")
+ }
+ ms := maps.Keys(modes)
+ sort.Strings(ms) // make deterministic for testing
+ return ms, nil
+ }
+ mode, err := vulncheckMode(modeParam)
+ if err != nil {
+ return nil, err
+ }
+ return []string{mode}, nil
+}
+
+func createVulncheckQueueTasks(ctx context.Context, cfg *config.Config, params *vulncheckEnqueueParams, modes []string) (_ []queue.Task, err error) {
+ defer derrors.Wrap(&err, "createVulncheckQueueTasks(%v)", modes)
var (
tasks []queue.Task
modspecs []scan.ModuleSpec
)
- for _, mode := range enqueueModes {
+ for _, mode := range modes {
var reqs []*vulncheckRequest
if mode == ModeBinary {
reqs, err = readBinaries(ctx, cfg.BinaryBucket)
if err != nil {
- return nil, nil, err
+ return nil, err
}
} else {
if modspecs == nil {
modspecs, err = readModules(ctx, cfg, params.File, params.Min)
if err != nil {
- return nil, nil, err
+ return nil, err
}
}
reqs = moduleSpecsToScanRequests(modspecs, mode)
@@ -97,7 +102,7 @@
}
}
}
- return tasks, params, nil
+ return tasks, nil
}
func vulncheckMode(mode string) (string, error) {
diff --git a/internal/worker/vulncheck_enqueue_test.go b/internal/worker/vulncheck_enqueue_test.go
index ad17ea3..9ecc0fe 100644
--- a/internal/worker/vulncheck_enqueue_test.go
+++ b/internal/worker/vulncheck_enqueue_test.go
@@ -7,7 +7,7 @@
import (
"context"
"flag"
- "net/http"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -57,11 +57,8 @@
}
}
- req, err := http.NewRequest("GET", "https://path?min=8&file=testdata/modules.txt", nil)
- if err != nil {
- t.Fatal(err)
- }
- gotTasks, _, err := createVulncheckQueueTasks(req, &config.Config{}, false)
+ params := &vulncheckEnqueueParams{Min: 8, File: "testdata/modules.txt"}
+ gotTasks, err := createVulncheckQueueTasks(context.Background(), &config.Config{}, params, []string{ModeVTA})
if err != nil {
t.Fatal(err)
}
@@ -74,7 +71,11 @@
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
- gotTasks, _, err = createVulncheckQueueTasks(req, &config.Config{}, true)
+ allModes, err := listModes("", true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotTasks, err = createVulncheckQueueTasks(context.Background(), &config.Config{}, params, allModes)
if err != nil {
t.Fatal(err)
}
@@ -90,3 +91,27 @@
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}
+
+func TestListModes(t *testing.T) {
+ for _, test := range []struct {
+ param string
+ all bool
+ want []string
+ wantErr bool
+ }{
+ {"", true, []string{ModeBinary, ModeGovulncheck, ModeImports, ModeVTA, ModeVTAStacks}, false},
+ {"", false, []string{ModeVTA}, false},
+ {"imports", false, []string{ModeImports}, false},
+ {"imports", true, nil, true},
+ } {
+ t.Run(fmt.Sprintf("%q,%t", test.param, test.all), func(t *testing.T) {
+ got, err := listModes(test.param, test.all)
+ if err != nil && !test.wantErr {
+ t.Fatal(err)
+ }
+ if err == nil && !cmp.Equal(got, test.want) {
+ t.Errorf("got %v, want %v", got, test.want)
+ }
+ })
+ }
+}