internal/queue: generalize
Break the dependency between internal/queue and internal/scan.
Queue.EnqueueScan takes a Task, which is a new interface that scan.Request
implements.
Change-Id: I5c0bebc556caf94cf497ad59de2290efec7eec70
Reviewed-on: https://go-review.googlesource.com/c/pkgsite-metrics/+/467315
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Julie Qiu <julieqiu@google.com>
diff --git a/internal/queue/queue.go b/internal/queue/queue.go
index 4c0f6be..ffd50f2 100644
--- a/internal/queue/queue.go
+++ b/internal/queue/queue.go
@@ -8,11 +8,11 @@
import (
"context"
+ "crypto/sha256"
+ "encoding/hex"
"errors"
"fmt"
- "hash/fnv"
"io"
- "math"
"strings"
"time"
@@ -20,18 +20,24 @@
"golang.org/x/pkgsite-metrics/internal/config"
"golang.org/x/pkgsite-metrics/internal/derrors"
"golang.org/x/pkgsite-metrics/internal/log"
- "golang.org/x/pkgsite-metrics/internal/scan"
taskspb "google.golang.org/genproto/googleapis/cloud/tasks/v2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
)
+// A Task can produce information needed for Cloud Tasks.
+type Task interface {
+ Name() string // Human-readable string for the task. Need not be unique.
+ Path() string // URL path
+ Params() string // URL query params
+}
+
// A Queue provides an interface for asynchronous scheduling of fetch actions.
type Queue interface {
// Enqueue a scan request.
// Reports whether a new task was actually added.
- EnqueueScan(context.Context, *scan.Request, *Options) (bool, error)
+ EnqueueScan(context.Context, Task, *Options) (bool, error)
}
// New creates a new Queue with name queueName based on the configuration
@@ -97,11 +103,12 @@
}, nil
}
-// EnqeueuScan enqueues a task on GCP to fetch the given modulePath and
-// version. It returns an error if there was an error hashing the task name, or
-// an error pushing the task to GCP. If the task was a duplicate, it returns (false, nil).
-func (q *GCP) EnqueueScan(ctx context.Context, sreq *scan.Request, opts *Options) (enqueued bool, err error) {
- defer derrors.WrapStack(&err, "queue.EnqueueScan(%v, %v)", sreq, opts)
+// EnqueueScan enqueues a task on GCP.
+// It returns an error if there was an error hashing the task name, or
+// an error pushing the task to GCP.
+// If the task was a duplicate, it returns (false, nil).
+func (q *GCP) EnqueueScan(ctx context.Context, task Task, opts *Options) (enqueued bool, err error) {
+ defer derrors.WrapStack(&err, "queue.EnqueueScan(%s, %s, %v)", task.Path(), task.Params(), opts)
if opts == nil {
opts = &Options{}
}
@@ -112,15 +119,15 @@
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
- req, err := q.newTaskRequest(sreq, opts)
+ req, err := q.newTaskRequest(task, opts)
if err != nil {
- return false, fmt.Errorf("q.newTaskRequest(modulePath, version, importedBy, opts): %v", err)
+ return false, fmt.Errorf("newTaskRequest: %v", err)
}
enqueued = true
if _, err := q.client.CreateTask(ctx, req); err != nil {
if status.Code(err) == codes.AlreadyExists {
- log.Debugf(ctx, "ignoring duplicate task ID %s: %s@%s", req.Task.Name, sreq.Module, sreq.Version)
+ log.Debugf(ctx, "ignoring duplicate task ID %s", req.Task.Name)
enqueued = false
} else {
return false, fmt.Errorf("q.client.CreateTask(ctx, req): %v", err)
@@ -146,67 +153,69 @@
// See https://cloud.google.com/tasks/docs/creating-http-target-tasks.
const maxCloudTasksTimeout = 30 * time.Minute
-const (
- DisableProxyFetchParam = "proxyfetch"
- DisableProxyFetchValue = "off"
-)
+const disableProxyFetchParam = "proxyfetch=off"
-func (q *GCP) newTaskRequest(sreq *scan.Request, opts *Options) (_ *taskspb.CreateTaskRequest, err error) {
- defer derrors.Wrap(&err, "newTaskRequest(%v, %v)", sreq, opts)
-
- if sreq.Mode == "" {
- return nil, errors.New("ScanRequest.Mode cannot be empty")
- }
+func (q *GCP) newTaskRequest(task Task, opts *Options) (*taskspb.CreateTaskRequest, error) {
if opts.Namespace == "" {
return nil, errors.New("Options.Namespace cannot be empty")
}
- taskID := newTaskID(sreq.Module, sreq.Version)
- relativeURI := fmt.Sprintf("/%s/scan/%s", opts.Namespace, sreq.URLPathAndParams())
- var params []string
+ relativeURI := fmt.Sprintf("/%s/scan/%s", opts.Namespace, task.Path())
+ params := task.Params()
if opts.DisableProxyFetch {
- params = append(params, fmt.Sprintf("%s=%s", DisableProxyFetchParam, DisableProxyFetchValue))
+ if params == "" {
+ params = disableProxyFetchParam
+ } else {
+ params += "&" + disableProxyFetchParam
+ }
}
- if len(params) > 0 {
- relativeURI += fmt.Sprintf("?%s", strings.Join(params, "&"))
+ if params != "" {
+ relativeURI += "?" + params
}
- task := &taskspb.Task{
+ taskID := newTaskID(opts.Namespace, task)
+ taskpb := &taskspb.Task{
Name: fmt.Sprintf("%s/tasks/%s", q.queueName, taskID),
DispatchDeadline: durationpb.New(maxCloudTasksTimeout),
- }
- task.MessageType = &taskspb.Task_HttpRequest{
- HttpRequest: &taskspb.HttpRequest{
- HttpMethod: taskspb.HttpMethod_POST,
- Url: q.queueURL + relativeURI,
- AuthorizationHeader: q.token,
+ MessageType: &taskspb.Task_HttpRequest{
+ HttpRequest: &taskspb.HttpRequest{
+ HttpMethod: taskspb.HttpMethod_POST,
+ Url: q.queueURL + relativeURI,
+ AuthorizationHeader: q.token,
+ },
},
}
req := &taskspb.CreateTaskRequest{
Parent: q.queueName,
- Task: task,
+ Task: taskpb,
}
- // If suffix is non-empty, append it to the task name. The same goes for mode.
+ // If suffix is non-empty, append it to the task name.
// This lets us force reprocessing of tasks that would normally be de-duplicated.
if opts.TaskNameSuffix != "" {
req.Task.Name += "-" + opts.TaskNameSuffix
}
- req.Task.Name += "-" + sreq.Mode
return req, nil
}
-// Create a task ID for the given module path and version.
+// Create a task ID for the given task.
+// Tasks with the same ID that are created within a few hours of each other will be de-duplicated.
+// See https://cloud.google.com/tasks/docs/reference/rpc/google.cloud.tasks.v2#createtaskrequest
+// under "Task De-duplication".
+func newTaskID(namespace string, task Task) string {
+ name := task.Name()
+ // Hash the path and params of the task.
+ hasher := sha256.New()
+ io.WriteString(hasher, task.Path())
+ io.WriteString(hasher, task.Params())
+ hash := hex.EncodeToString(hasher.Sum(nil))
+ return escapeTaskID(fmt.Sprintf("%s-%s-%s", name, namespace, hash[:8]))
+}
+
+// escapeTaskID escapes s so it contains only valid characters for a Cloud Tasks name.
+// It tries to produce a readable result.
// Task IDs can contain only letters ([A-Za-z]), numbers ([0-9]), hyphens (-), or underscores (_).
-func newTaskID(modulePath, version string) string {
- mv := modulePath + "@" + version
- // Compute a hash to use as a prefix, so the task IDs are distributed uniformly.
- // See https://cloud.google.com/tasks/docs/reference/rpc/google.cloud.tasks.v2#task
- // under "Task De-duplication".
- hasher := fnv.New32()
- io.WriteString(hasher, mv)
- hash := hasher.Sum32() % math.MaxUint16
- // Escape the name so it contains only valid characters. Do our best to make it readable.
+func escapeTaskID(s string) string {
var b strings.Builder
- for _, r := range mv {
+ for _, r := range s {
switch {
case r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-':
b.WriteRune(r)
@@ -215,14 +224,14 @@
case r == '/':
b.WriteString("_-")
case r == '@':
- b.WriteString("_v")
+ b.WriteString("_")
case r == '.':
- b.WriteString("_o")
+ b.WriteString("_")
default:
fmt.Fprintf(&b, "_%04x", r)
}
}
- return fmt.Sprintf("%04x-%s", hash, &b)
+ return b.String()
}
// InMemory is a Queue implementation that schedules in-process fetch
@@ -231,18 +240,18 @@
//
// This should only be used for local development.
type InMemory struct {
- queue chan *scan.Request
+ queue chan Task
done chan struct{}
}
-type inMemoryProcessFunc func(context.Context, *scan.Request) (int, error)
+type inMemoryProcessFunc func(context.Context, Task) (int, error)
// NewInMemory creates a new InMemory that asynchronously fetches
// from proxyClient and stores in db. It uses workerCount parallelism to
// execute these fetches.
func NewInMemory(ctx context.Context, workerCount int, processFunc inMemoryProcessFunc) *InMemory {
q := &InMemory{
- queue: make(chan *scan.Request, 1000),
+ queue: make(chan Task, 1000),
done: make(chan struct{}),
}
sem := make(chan struct{}, workerCount)
@@ -256,16 +265,16 @@
// If a worker is available, make a request to the fetch service inside a
// goroutine and wait for it to finish.
- go func(r *scan.Request) {
+ go func(t Task) {
defer func() { <-sem }()
- log.Infof(ctx, "Fetch requested: %v (workerCount = %d)", r, cap(sem))
+ log.Infof(ctx, "Fetch requested: %v (workerCount = %d)", t, cap(sem))
fetchCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
- if _, err := processFunc(fetchCtx, r); err != nil {
- log.Errorf(fetchCtx, "processFunc(%q, %q): %v", r.Path, r.Version, err)
+ if _, err := processFunc(fetchCtx, t); err != nil {
+ log.Errorf(fetchCtx, "processFunc(%v, %q): %v", t, err)
}
}(v)
}
@@ -281,10 +290,10 @@
return q
}
-// EnqeueuScan pushes a fetch task into the local queue to be processed
+// EnqueueScan pushes a fetch task into the local queue to be processed
// asynchronously.
-func (q *InMemory) EnqueueScan(ctx context.Context, req *scan.Request, _ *Options) (bool, error) {
- q.queue <- req
+func (q *InMemory) EnqueueScan(ctx context.Context, task Task, _ *Options) (bool, error) {
+ q.queue <- task
return true, nil
}
diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go
index be5ec81..5b91c65 100644
--- a/internal/queue/queue_test.go
+++ b/internal/queue/queue_test.go
@@ -11,22 +11,38 @@
"golang.org/x/pkgsite-metrics/internal/config"
"golang.org/x/pkgsite-metrics/internal/scan"
taskspb "google.golang.org/genproto/googleapis/cloud/tasks/v2"
- "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/durationpb"
)
+type testTask struct {
+ name string
+ path string
+ params string
+}
+
+func (t *testTask) Name() string { return t.name }
+func (t *testTask) Path() string { return t.path }
+func (t *testTask) Params() string { return t.params }
+
func TestNewTaskID(t *testing.T) {
for _, test := range []struct {
- modulePath, version string
- want string
+ name, path, params string
+ want string
}{
- {"m-1", "v2", "acc5-m-1_vv2"},
- {"my_module", "v1.2.3", "0cb9-my__module_vv1_o2_o3"},
- {"µπΩ/github.com", "v2.3.4-ß", "a49c-_00b5_03c0_03a9_-github_ocom_vv2_o3_o4-_00df"},
+ {
+ "m@v1.2", "path", "params",
+ "m_v1_2-ns-31026413",
+ },
+ {
+ "µπΩ/github.com@v2.3.4-ß", "p", "",
+ "_00b5_03c0_03a9_-github_com_v2_3_4-_00df-ns-148de9c5",
+ },
} {
- got := newTaskID(test.modulePath, test.version)
+ tt := &testTask{test.name, test.path, test.params}
+ got := newTaskID("ns", tt)
if got != test.want {
- t.Errorf("%s@%s: got %s, want %s", test.modulePath, test.version, got, test.want)
+ t.Errorf("%v: got %s, want %s", tt, got, test.want)
}
}
}
@@ -45,7 +61,7 @@
MessageType: &taskspb.Task_HttpRequest{
HttpRequest: &taskspb.HttpRequest{
HttpMethod: taskspb.HttpMethod_POST,
- Url: "http://1.2.3.4:8000/test/scan/mod/@v/v1.2.3?importedby=0&mode=test&insecure=true",
+ Url: "http://1.2.3.4:8000/test/scan/mod@v1.2.3?importedby=0&mode=test&insecure=true",
AuthorizationHeader: &taskspb.HttpRequest_OidcToken{
OidcToken: &taskspb.OidcToken{
ServiceAccountEmail: "sa",
@@ -79,18 +95,18 @@
t.Fatal(err)
}
want.Task.Name = got.Task.Name
- if diff := cmp.Diff(want, got, cmp.Comparer(proto.Equal)); diff != "" {
+ if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
opts.DisableProxyFetch = true
- want.Task.MessageType.(*taskspb.Task_HttpRequest).HttpRequest.Url += "?proxyfetch=off"
+ want.Task.MessageType.(*taskspb.Task_HttpRequest).HttpRequest.Url += "&proxyfetch=off"
got, err = gcp.newTaskRequest(sreq, opts)
if err != nil {
t.Fatal(err)
}
want.Task.Name = got.Task.Name
- if diff := cmp.Diff(want, got, cmp.Comparer(proto.Equal)); diff != "" {
+ if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
diff --git a/internal/scan/parse.go b/internal/scan/parse.go
index 76aeaaa..8b27e73 100644
--- a/internal/scan/parse.go
+++ b/internal/scan/parse.go
@@ -33,13 +33,8 @@
Insecure bool
}
-func (r *Request) URLPathAndParams() string {
- suf := r.Suffix
- if suf != "" {
- suf = "/" + suf
- }
- return fmt.Sprintf("%s/@v/%s%s?importedby=%d&mode=%s&insecure=%t", r.Module, r.Version, suf, r.ImportedBy, r.Mode, r.Insecure)
-}
+// These methods implement queue.Task.
+func (r *Request) Name() string { return r.Module + "@" + r.Version }
func (r *Request) Path() string {
p := r.Module + "@" + r.Version
@@ -49,6 +44,10 @@
return p
}
+func (r *Request) Params() string {
+ return FormatParams(r.RequestParams)
+}
+
// ParseRequest parses an http request r for an endpoint
// scanPrefix and produces a corresponding ScanRequest.
//
@@ -257,12 +256,12 @@
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
paramName := strings.ToLower(f.Name)
- param := r.FormValue(paramName)
- if param == "" {
+ paramValue := r.FormValue(paramName)
+ if paramValue == "" {
// If param is missing, do not set field.
continue
}
- pval, err := parseParam(param, f.Type.Kind())
+ pval, err := parseParam(paramValue, f.Type.Kind())
if err != nil {
return fmt.Errorf("param %s: %v", paramName, err)
}
@@ -283,3 +282,24 @@
return nil, fmt.Errorf("cannot parse kind %s", kind)
}
}
+
+// FormatParams takes a struct or struct pointer, and returns
+// a URL query-param string with the struct field values.
+func FormatParams(s any) string {
+ v := reflect.ValueOf(s)
+ t := v.Type()
+ if t.Kind() == reflect.Pointer {
+ t = t.Elem()
+ v = v.Elem()
+ }
+ if t.Kind() != reflect.Struct {
+ panic(fmt.Sprintf("need struct or struct pointer, got %T", s))
+ }
+ var params []string
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ params = append(params,
+ fmt.Sprintf("%s=%v", strings.ToLower(f.Name), v.Field(i)))
+ }
+ return strings.Join(params, "&")
+}
diff --git a/internal/scan/parse_test.go b/internal/scan/parse_test.go
index 19f86d2..f25967a 100644
--- a/internal/scan/parse_test.go
+++ b/internal/scan/parse_test.go
@@ -156,36 +156,36 @@
}
}
-func TestParseParams(t *testing.T) {
- type S struct {
- Str string
- Int int
- Bool bool
- }
+type params struct {
+ Str string
+ Int int
+ Bool bool
+}
+func TestParseParams(t *testing.T) {
t.Run("success", func(t *testing.T) {
for _, test := range []struct {
params string
- want S
+ want params
}{
{
"str=foo&int=1&bool=true",
- S{Str: "foo", Int: 1, Bool: true},
+ params{Str: "foo", Int: 1, Bool: true},
},
{
"", // all defaults
- S{Str: "d", Int: 17, Bool: false},
+ params{Str: "d", Int: 17, Bool: false},
},
{
"int=3&bool=t&str=", // empty string is same as default
- S{Str: "d", Int: 3, Bool: true},
+ params{Str: "d", Int: 3, Bool: true},
},
} {
r, err := http.NewRequest("GET", "https://path?"+test.params, nil)
if err != nil {
t.Fatal(err)
}
- got := S{Str: "d", Int: 17} // set defaults
+ got := params{Str: "d", Int: 17} // set defaults
if err := ParseParams(r, &got); err != nil {
t.Fatal(err)
}
@@ -201,8 +201,8 @@
errContains string
}{
{3, "", "struct pointer"},
- {&S{}, "int=foo", "invalid syntax"},
- {&S{}, "bool=foo", "invalid syntax"},
+ {&params{}, "int=foo", "invalid syntax"},
+ {&params{}, "bool=foo", "invalid syntax"},
{&struct{ F float64 }{}, "f=1.1", "cannot parse kind"},
} {
r, err := http.NewRequest("GET", "https://path?"+test.params, nil)
@@ -220,3 +220,11 @@
}
})
}
+
+func TestFormatParams(t *testing.T) {
+ got := FormatParams(params{Str: "foo", Int: 17, Bool: true})
+ want := "str=foo&int=17&bool=true"
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+}
diff --git a/internal/worker/enqueue.go b/internal/worker/enqueue.go
index d80f09f..804fab7 100644
--- a/internal/worker/enqueue.go
+++ b/internal/worker/enqueue.go
@@ -6,6 +6,7 @@
import (
"context"
+ "errors"
"sync"
"golang.org/x/pkgsite-metrics/internal/config"
@@ -48,10 +49,13 @@
sem := make(chan struct{}, concurrentEnqueues)
for _, sreq := range sreqs {
- log.Infof(ctx, "enqueuing: %s", sreq.URLPathAndParams())
+ log.Infof(ctx, "enqueuing: %s?%s", sreq.Path(), sreq.Params())
if sreq.Module == "std" {
continue // ignore the standard library
}
+ if sreq.Mode == "" {
+ return errors.New("ScanRequest.Mode cannot be empty")
+ }
sreq := sreq
sem <- struct{}{}
go func() {
diff --git a/internal/worker/server.go b/internal/worker/server.go
index 06e9ad2..1bb361f 100644
--- a/internal/worker/server.go
+++ b/internal/worker/server.go
@@ -26,7 +26,6 @@
"golang.org/x/pkgsite-metrics/internal/observe"
"golang.org/x/pkgsite-metrics/internal/proxy"
"golang.org/x/pkgsite-metrics/internal/queue"
- "golang.org/x/pkgsite-metrics/internal/scan"
vulnc "golang.org/x/vuln/client"
)
@@ -71,10 +70,10 @@
}
q, err := queue.New(ctx, cfg,
- func(ctx context.Context, sreq *scan.Request) (int, error) {
+ func(ctx context.Context, t queue.Task) (int, error) {
// When running locally, only the module path and version are
// printed for now.
- log.Infof(ctx, "enqueuing %s", sreq.URLPathAndParams())
+ log.Infof(ctx, "enqueuing %s?%s", t.Path(), t.Params())
return 0, nil
})
if err != nil {