| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package tasks provides a Google Cloud Tasks implementation |
| // as a queue used for asynchronous scheduling of fetch actions. |
| package tasks |
| |
| import ( |
| "context" |
| "crypto/sha256" |
| "encoding/hex" |
| "errors" |
| "fmt" |
| "io" |
| "strings" |
| "time" |
| |
| cloudtasks "cloud.google.com/go/cloudtasks/apiv2" |
| taskspb "cloud.google.com/go/cloudtasks/apiv2/cloudtaskspb" |
| "golang.org/x/oscar/internal/queue" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/status" |
| "google.golang.org/protobuf/types/known/durationpb" |
| ) |
| |
| // New creates a new Cloud Tasks [queue.Queue] based on metadata m. |
| func New(ctx context.Context, m *queue.Metadata) (queue.Queue, error) { |
| client, err := cloudtasks.NewClient(ctx) |
| if err != nil { |
| return nil, err |
| } |
| g, err := newQueue(client, m) |
| if err != nil { |
| return nil, err |
| } |
| return g, nil |
| } |
| |
// Queue is a [queue.Queue] backed by Google Cloud Tasks.
// Instances are built by New (via newQueue). The zero value has a nil
// client and is presumably not usable.
type Queue struct {
	client *cloudtasks.Client // Cloud Tasks API client used to create tasks
	name   string             // full GCP name of the queue: projects/<project>/locations/<location>/queues/<queue>
	url    string             // non-AppEngine URL to post tasks to; each task's relative URI is appended to it
	// token holds information that lets the task queue construct an authorized request to the worker.
	// Since the worker sits behind the IAP, the queue needs an identity token that includes the
	// identity of a service account that has access, and the client ID for the IAP.
	// We use the service account of the current process.
	// NOTE(review): newQueue fills in only ServiceAccountEmail (taken from m.ServiceAccount)
	// and sets no Audience, so no IAP client ID is configured here. Confirm that the
	// default audience Cloud Tasks uses is what the IAP expects, and that m.ServiceAccount
	// is indeed the current process's account.
	token *taskspb.HttpRequest_OidcToken
}
| |
| // newQueue returns a new Queue based on metadata m that can be used to |
| // enqueue tasks using the cloud tasks API. The given m.QueueName |
| // should be the name of the queue in the cloud tasks console. |
| func newQueue(client *cloudtasks.Client, m *queue.Metadata) (*Queue, error) { |
| if m.QueueName == "" { |
| return nil, errors.New("empty queue name") |
| } |
| if m.Project == "" { |
| return nil, errors.New("empty project") |
| } |
| if m.QueueURL == "" { |
| return nil, errors.New("empty queue URL") |
| } |
| if m.ServiceAccount == "" { |
| return nil, errors.New("empty serviceAccount") |
| } |
| if m.Location == "" { |
| return nil, errors.New("empty location") |
| } |
| return &Queue{ |
| client: client, |
| name: fmt.Sprintf("projects/%s/locations/%s/queues/%s", m.Project, m.Location, m.QueueName), |
| url: m.QueueURL, |
| token: &taskspb.HttpRequest_OidcToken{ |
| OidcToken: &taskspb.OidcToken{ |
| ServiceAccountEmail: m.ServiceAccount, |
| }, |
| }, |
| }, nil |
| } |
| |
| // Enqueue enqueues a task in q. |
| // It returns an error if there was an error hashing the task name, or |
| // an error pushing the task to GCP. |
| // If the task was a duplicate, it returns (false, nil). |
| func (q *Queue) Enqueue(ctx context.Context, task queue.Task, opts *queue.Options) (bool, error) { |
| if opts == nil { |
| opts = &queue.Options{} |
| } |
| // Cloud Tasks enforces an RPC timeout of at most 30s. I couldn't find this |
| // in the documentation, but using a larger value, or no timeout, results in |
| // an InvalidArgument error with the text "The deadline cannot be more than |
| // 30s in the future." |
| ctx, cancel := context.WithTimeout(ctx, 30*time.Second) |
| defer cancel() |
| |
| req, err := q.newTaskRequest(task, opts) |
| if err != nil { |
| return false, fmt.Errorf("newTaskRequest: %v", err) |
| } |
| |
| _, err = q.client.CreateTask(ctx, req) |
| if err == nil { |
| return true, nil |
| } |
| if status.Code(err) == codes.AlreadyExists { |
| return false, nil |
| } |
| return false, fmt.Errorf("q.client.CreateTask(ctx, req): %v", err) |
| } |
| |
// maxCloudTasksTimeout is the upper bound Cloud Tasks allows for the dispatch
// deadline of an HTTP-target task; it is used as every task's DispatchDeadline.
// See https://cloud.google.com/tasks/docs/creating-http-target-tasks.
const maxCloudTasksTimeout time.Duration = 30 * time.Minute
| |
| func (q *Queue) newTaskRequest(task queue.Task, opts *queue.Options) (*taskspb.CreateTaskRequest, error) { |
| relativeURI := "/" + task.Path() |
| if params := task.Params(); params != "" { |
| relativeURI += "?" + params |
| } |
| |
| taskID := newTaskID(task) |
| taskpb := &taskspb.Task{ |
| Name: fmt.Sprintf("%s/tasks/%s", q.name, taskID), |
| DispatchDeadline: durationpb.New(maxCloudTasksTimeout), |
| MessageType: &taskspb.Task_HttpRequest{ |
| HttpRequest: &taskspb.HttpRequest{ |
| HttpMethod: taskspb.HttpMethod_POST, |
| Url: q.url + relativeURI, |
| AuthorizationHeader: q.token, |
| }, |
| }, |
| } |
| req := &taskspb.CreateTaskRequest{ |
| Parent: q.name, |
| Task: taskpb, |
| } |
| // If suffix is non-empty, append it to the task name. |
| // This lets us force reprocessing of tasks that would normally be de-duplicated. |
| if opts.TaskNameSuffix != "" { |
| req.Task.Name += "-" + opts.TaskNameSuffix |
| } |
| return req, nil |
| } |
| |
| // newTaskID creates a task ID for the given task. |
| // Tasks with the same ID that are created within a few hours of each other will be de-duplicated. |
| // See https://cloud.google.com/tasks/docs/reference/rpc/google.cloud.tasks.v2#createtaskrequest |
| // under "Task De-duplication". |
| func newTaskID(task queue.Task) string { |
| name := task.Name() |
| // Hash the path, params, and body of the task. |
| hasher := sha256.New() |
| io.WriteString(hasher, task.Path()) |
| io.WriteString(hasher, task.Params()) |
| hash := hex.EncodeToString(hasher.Sum(nil)) |
| return escapeTaskID(fmt.Sprintf("%s-%s", name, hash[:8])) |
| } |
| |
// escapeTaskID escapes s so it contains only valid characters for a Cloud Tasks name.
// It tries to produce a readable result.
// Task IDs can contain only letters ([A-Za-z]), numbers ([0-9]), hyphens (-), or underscores (_).
//
// Letters, digits and hyphens are copied unchanged. Every other rune becomes
// an underscore followed by a code:
//   - "__" for '_'
//   - "_-" for '/'
//   - "_v" for '@'
//   - "_o" for '.'
//   - otherwise "_" plus the code point in at least four lowercase hex digits
//
// Each code starts with a different character after the '_', so for valid
// UTF-8 whose runes are all at most U+FFFF, different inputs always escape to
// different IDs. Distinct task names therefore cannot be merged, and wrongly
// de-duplicated, by escaping. Runes above U+FFFF print more than four hex
// digits, and invalid bytes decode as U+FFFD, so such inputs may still collide.
func escapeTaskID(s string) string {
	var b strings.Builder
	for _, r := range s {
		switch {
		case r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-':
			b.WriteRune(r)
		case r == '_':
			b.WriteString("__")
		case r == '/':
			b.WriteString("_-")
		case r == '@':
			// Must not be a bare "_": that would collide with '.' and with
			// the two-character codes above.
			b.WriteString("_v")
		case r == '.':
			b.WriteString("_o")
		default:
			fmt.Fprintf(&b, "_%04x", r)
		}
	}
	return b.String()
}