internal/worker: test vulncheck enqueue logic
Refactor the vulncheck enqueue handlers to test the logic that creates tasks.
Also fix a bug with parsing the query params: pass params instead of
¶ms to scan.ParseParams.
Change-Id: Id340958baf543d0732bcb889d180ce3a4f66ed5a
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/470255
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Reviewed-by: Zvonimir Pavlinovic <zpavlinovic@google.com>
diff --git a/internal/worker/testdata/modules.txt b/internal/worker/testdata/modules.txt
new file mode 100644
index 0000000..77010d2
--- /dev/null
+++ b/internal/worker/testdata/modules.txt
@@ -0,0 +1,4 @@
+std v1.19.4 2025760
+github.com/pkg/errors v0.9.1 10
+golang.org/x/net v0.4.0 20
+
diff --git a/internal/worker/vulncheck_enqueue.go b/internal/worker/vulncheck_enqueue.go
index b3ddb54..8302b98 100644
--- a/internal/worker/vulncheck_enqueue.go
+++ b/internal/worker/vulncheck_enqueue.go
@@ -6,11 +6,15 @@
import (
"context"
+ "errors"
"fmt"
"net/http"
+ "sort"
"strings"
"cloud.google.com/go/storage"
+ "golang.org/x/exp/maps"
+ "golang.org/x/pkgsite-metrics/internal/config"
"golang.org/x/pkgsite-metrics/internal/derrors"
"golang.org/x/pkgsite-metrics/internal/log"
"golang.org/x/pkgsite-metrics/internal/queue"
@@ -28,40 +32,74 @@
// handleEnqueue enqueues multiple modules for a single vulncheck mode.
func (h *VulncheckServer) handleEnqueue(w http.ResponseWriter, r *http.Request) error {
- params := &vulncheckEnqueueParams{Min: defaultMinImportedByCount}
- if err := scan.ParseParams(r, ¶ms); err != nil {
- return err
- }
- ctx := r.Context()
- mode, err := vulncheckMode(params.Mode)
+ return h.enqueue(r, false)
+}
+
+// handleEnqueueAll enqueues multiple modules for all vulncheck modes.
+func (h *VulncheckServer) handleEnqueueAll(w http.ResponseWriter, r *http.Request) error {
+ return h.enqueue(r, true)
+}
+
+func (h *VulncheckServer) enqueue(r *http.Request, allModes bool) error {
+ tasks, params, err := createVulncheckQueueTasks(r, h.cfg, allModes)
if err != nil {
return err
}
-
- var reqs []*vulncheckRequest
- if mode == ModeBinary {
- var err error
- reqs, err = readBinaries(ctx, h.cfg.BinaryBucket)
- if err != nil {
- return err
- }
- } else {
- modspecs, err := readModules(ctx, h.cfg, params.File, params.Min)
- if err != nil {
- return err
- }
- reqs = moduleSpecsToScanRequests(modspecs, mode)
- }
- var sreqs []queue.Task
- for _, req := range reqs {
- if req.Module != "std" { // ignore the standard library
- sreqs = append(sreqs, req)
- }
- }
- return enqueueTasks(ctx, sreqs, h.queue,
+ return enqueueTasks(r.Context(), tasks, h.queue,
&queue.Options{Namespace: "vulncheck", TaskNameSuffix: params.Suffix})
}
+func createVulncheckQueueTasks(r *http.Request, cfg *config.Config, allModes bool) (_ []queue.Task, _ *vulncheckEnqueueParams, err error) {
+ defer derrors.Wrap(&err, "createQueueTasks(%s, %t)", r.URL, allModes)
+ ctx := r.Context()
+ params := &vulncheckEnqueueParams{Min: defaultMinImportedByCount}
+ if err := scan.ParseParams(r, params); err != nil {
+ return nil, nil, err
+ }
+ if allModes && params.Mode != "" {
+ return nil, nil, errors.New("mode query param provided for enqueueAll")
+ }
+ var enqueueModes []string
+ if allModes {
+ enqueueModes = maps.Keys(modes)
+ sort.Strings(enqueueModes) // make deterministic for testing
+ } else {
+ mode, err := vulncheckMode(params.Mode)
+ if err != nil {
+ return nil, nil, err
+ }
+ enqueueModes = []string{mode}
+ }
+
+ var (
+ tasks []queue.Task
+ modspecs []scan.ModuleSpec
+ )
+ for _, mode := range enqueueModes {
+ var reqs []*vulncheckRequest
+ if mode == ModeBinary {
+ reqs, err = readBinaries(ctx, cfg.BinaryBucket)
+ if err != nil {
+ return nil, nil, err
+ }
+ } else {
+ if modspecs == nil {
+ modspecs, err = readModules(ctx, cfg, params.File, params.Min)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ reqs = moduleSpecsToScanRequests(modspecs, mode)
+ }
+ for _, req := range reqs {
+ if req.Module != "std" { // ignore the standard library
+ tasks = append(tasks, req)
+ }
+ }
+ }
+ return tasks, params, nil
+}
+
func vulncheckMode(mode string) (string, error) {
if mode == "" {
// VTA is the default mode
@@ -74,46 +112,6 @@
return mode, nil
}
-// handleEnqueueAll enqueues multiple modules for all vulncheck modes.
-// Query params:
-// - suffix: appended to task queue IDs to generate unique tasks
-// - file: path to file containing modules; if missing, use DB
-// - min: minimum import-by count for a module to be included
-func (h *VulncheckServer) handleEnqueueAll(w http.ResponseWriter, r *http.Request) error {
- params := &vulncheckEnqueueParams{Min: defaultMinImportedByCount}
- if err := scan.ParseParams(r, ¶ms); err != nil {
- return err
- }
-
- ctx := r.Context()
- modspecs, err := readModules(ctx, h.cfg, params.File, params.Min)
- if err != nil {
- return err
- }
- opts := &queue.Options{Namespace: "vulncheck", TaskNameSuffix: params.Suffix}
- for mode := range modes {
- var reqs []*vulncheckRequest
- if mode == ModeBinary {
- reqs, err = readBinaries(ctx, h.cfg.BinaryBucket)
- if err != nil {
- return err
- }
- } else {
- reqs = moduleSpecsToScanRequests(modspecs, mode)
- }
- var tasks []queue.Task
- for _, req := range reqs {
- if req.Module != "std" { // ignore the standard library
- tasks = append(tasks, req)
- }
- }
- if err := enqueueTasks(ctx, tasks, h.queue, opts); err != nil {
- return err
- }
- }
- return nil
-}
-
// binaryDir is the directory in the GCS bucket that contains binaries that should be scanned.
const binaryDir = "binaries"
diff --git a/internal/worker/vulncheck_enqueue_test.go b/internal/worker/vulncheck_enqueue_test.go
index 931e9df..8707d06 100644
--- a/internal/worker/vulncheck_enqueue_test.go
+++ b/internal/worker/vulncheck_enqueue_test.go
@@ -7,8 +7,12 @@
import (
"context"
"flag"
+ "net/http"
"testing"
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/pkgsite-metrics/internal/config"
+ "golang.org/x/pkgsite-metrics/internal/queue"
"golang.org/x/pkgsite-metrics/internal/scan"
)
@@ -44,3 +48,45 @@
}
}
}
+
+func TestCreateQueueTasks(t *testing.T) {
+ vreq := func(path, version, mode string, importedBy int) *vulncheckRequest {
+ return &vulncheckRequest{
+ scan.ModuleURLPath{Module: path, Version: version},
+ vulncheckRequestParams{Mode: mode, ImportedBy: importedBy},
+ }
+ }
+
+ req, err := http.NewRequest("GET", "https://path?min=8&file=testdata/modules.txt", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotTasks, _, err := createVulncheckQueueTasks(req, &config.Config{}, false)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wantTasks := []queue.Task{
+ vreq("github.com/pkg/errors", "v0.9.1", ModeVTA, 10),
+ vreq("golang.org/x/net", "v0.4.0", ModeVTA, 20),
+ }
+ if diff := cmp.Diff(wantTasks, gotTasks, cmp.AllowUnexported(vulncheckRequest{})); diff != "" {
+ t.Errorf("mismatch (-want, +got):\n%s", diff)
+ }
+
+ gotTasks, _, err = createVulncheckQueueTasks(req, &config.Config{}, true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTasks = nil
+ // cfg.BinaryBucket is empty, so no binary-mode tasks are created.
+ for _, mode := range []string{ModeImports, ModeVTA, ModeVTAStacks} {
+ wantTasks = append(wantTasks,
+ vreq("github.com/pkg/errors", "v0.9.1", mode, 10),
+ vreq("golang.org/x/net", "v0.4.0", mode, 20))
+ }
+
+ if diff := cmp.Diff(wantTasks, gotTasks, cmp.AllowUnexported(vulncheckRequest{})); diff != "" {
+ t.Errorf("mismatch (-want, +got):\n%s", diff)
+ }
+}