{cmd/internal}/worker: add update command

Add a command to the worker CLI that updates the repo.

Add flags to read the repo locally and to use a pkgsite URL
other than the default (like staging, for testing).

By default, check that the update makes sense. Add a -force
flag to override.

Change-Id: Ib46963b84489a443ec018987bb58a2f457ae59ee
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/368596
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/cmd/cvetriage/worker.go b/cmd/cvetriage/worker.go
index f310050..a70f27d 100644
--- a/cmd/cvetriage/worker.go
+++ b/cmd/cvetriage/worker.go
@@ -29,7 +29,7 @@
 	if dirpath != "" {
 		repo, err = gitrepo.Open(dirpath)
 	} else {
-		repo, err = gitrepo.Clone(gitrepo.CVElistRepoURL)
+		repo, err = gitrepo.Clone(gitrepo.CVEListRepoURL)
 	}
 	if err != nil {
 		return err
diff --git a/cmd/worker/main.go b/cmd/worker/main.go
index 530cc03..301631f 100644
--- a/cmd/worker/main.go
+++ b/cmd/worker/main.go
@@ -18,6 +18,7 @@
 
 	"cloud.google.com/go/errorreporting"
 	"golang.org/x/vuln/internal/derrors"
+	"golang.org/x/vuln/internal/gitrepo"
 	"golang.org/x/vuln/internal/worker"
 	"golang.org/x/vuln/internal/worker/log"
 	"golang.org/x/vuln/internal/worker/store"
@@ -27,6 +28,9 @@
 	project        = flag.String("project", os.Getenv("GOOGLE_CLOUD_PROJECT"), "project ID")
 	namespace      = flag.String("namespace", os.Getenv("VULN_WORKER_NAMESPACE"), "Firestore namespace")
 	errorReporting = flag.Bool("reporterrors", os.Getenv("VULN_WORKER_REPORT_ERRORS") == "true", "use the error reporting API")
+	pkgsiteURL     = flag.String("pkgsite", "https://pkg.go.dev", "URL to pkgsite")
+	localRepoPath  = flag.String("repo", "", "path to local repo, instead of cloning remote")
+	force          = flag.Bool("force", false, "force an update to happen")
 )
 
 const serviceID = "vuln-worker"
@@ -88,6 +92,11 @@
 	switch flag.Arg(0) {
 	case "list-updates":
 		return listUpdatesCommand(ctx, st)
+	case "update":
+		if flag.NArg() != 2 {
+			return errors.New("usage: update COMMIT")
+		}
+		return updateCommand(ctx, st, flag.Arg(1))
 	default:
 		return fmt.Errorf("unknown command: %q", flag.Arg(1))
 	}
@@ -114,6 +123,18 @@
 	return tw.Flush()
 }
 
+func updateCommand(ctx context.Context, st store.Store, commitHash string) error {
+	repoPath := gitrepo.CVEListRepoURL
+	if *localRepoPath != "" {
+		repoPath = *localRepoPath
+	}
+	err := worker.UpdateCommit(ctx, repoPath, commitHash, st, *pkgsiteURL, *force)
+	if cerr := new(worker.CheckUpdateError); errors.As(err, &cerr) {
+		return fmt.Errorf("%w; use -force to override", cerr)
+	}
+	return err
+}
+
 func die(format string, args ...interface{}) {
 	fmt.Fprintf(os.Stderr, format, args...)
 	fmt.Fprintln(os.Stderr)
diff --git a/internal/gitrepo/gitrepo.go b/internal/gitrepo/gitrepo.go
index 4c40da0..ba68235 100644
--- a/internal/gitrepo/gitrepo.go
+++ b/internal/gitrepo/gitrepo.go
@@ -7,6 +7,7 @@
 
 import (
 	"context"
+	"strings"
 
 	"github.com/go-git/go-git/v5"
 	"github.com/go-git/go-git/v5/plumbing"
@@ -16,7 +17,7 @@
 	"golang.org/x/vuln/internal/worker/log"
 )
 
-const CVElistRepoURL = "https://github.com/CVEProject/cvelist"
+const CVEListRepoURL = "https://github.com/CVEProject/cvelist"
 
 // Clone returns a repo by cloning the repo at repoURL.
 func Clone(repoURL string) (repo *git.Repository, err error) {
@@ -42,6 +43,15 @@
 	return repo, nil
 }
 
+// CloneOrOpen clones repoPath if it is an HTTP(S) URL, or opens it from the
+// local disk otherwise.
+func CloneOrOpen(repoPath string) (*git.Repository, error) {
+	if strings.HasPrefix(repoPath, "http://") || strings.HasPrefix(repoPath, "https://") {
+		return Clone(repoPath)
+	}
+	return Open(repoPath)
+}
+
 // Root returns the root tree of the repo at HEAD.
 func Root(repo *git.Repository) (root *object.Tree, err error) {
 	refName := plumbing.HEAD
diff --git a/internal/worker/repo_test.go b/internal/worker/repo_test.go
index b1ae019..475d523 100644
--- a/internal/worker/repo_test.go
+++ b/internal/worker/repo_test.go
@@ -19,7 +19,7 @@
 
 // readTxtarRepo converts a txtar file to a single-commit
 // repo.
-func readTxtarRepo(filename string) (_ *git.Repository, err error) {
+func readTxtarRepo(filename string, now time.Time) (_ *git.Repository, err error) {
 	defer derrors.Wrap(&err, "readTxtarRepo(%q)", filename)
 
 	mfs := memfs.New()
@@ -56,7 +56,7 @@
 	_, err = wt.Commit("", &git.CommitOptions{All: true, Author: &object.Signature{
 		Name:  "Joe Random",
 		Email: "joe@example.com",
-		When:  time.Now(),
+		When:  now,
 	}})
 	if err != nil {
 		return nil, err
diff --git a/internal/worker/store/fire_store.go b/internal/worker/store/fire_store.go
index 90078df..50eb5dc 100644
--- a/internal/worker/store/fire_store.go
+++ b/internal/worker/store/fire_store.go
@@ -7,7 +7,6 @@
 import (
 	"context"
 	"errors"
-	"sort"
 
 	"cloud.google.com/go/firestore"
 	"golang.org/x/vuln/internal/derrors"
@@ -86,7 +85,7 @@
 // ListCommitUpdateRecords implements Store.ListCommitUpdateRecords.
 func (fs *FireStore) ListCommitUpdateRecords(ctx context.Context, limit int) ([]*CommitUpdateRecord, error) {
 	var urs []*CommitUpdateRecord
-	q := fs.nsDoc.Collection(updateCollection).Query
+	q := fs.nsDoc.Collection(updateCollection).OrderBy("StartedAt", firestore.Desc)
 	if limit > 0 {
 		q = q.Limit(limit)
 	}
@@ -106,9 +105,6 @@
 		ur.ID = docsnap.Ref.ID
 		urs = append(urs, &ur)
 	}
-	sort.Slice(urs, func(i, j int) bool {
-		return urs[i].StartedAt.After(urs[j].StartedAt)
-	})
 	return urs, nil
 }
 
diff --git a/internal/worker/triage.go b/internal/worker/triage.go
index 07ac2ca..37fc67c 100644
--- a/internal/worker/triage.go
+++ b/internal/worker/triage.go
@@ -40,11 +40,11 @@
 
 // TriageCVE reports whether the CVE refers to a
 // Go module.
-func TriageCVE(c *cveschema.CVE) (_ bool, err error) {
+func TriageCVE(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (_ bool, err error) {
 	defer derrors.Wrap(&err, "triageCVE(%q)", c.ID)
 	switch c.DataVersion {
 	case "4.0":
-		mp, err := cveModulePath(context.TODO(), c)
+		mp, err := cveModulePath(ctx, c, pkgsiteURL)
 		if err != nil {
 			return false, err
 		}
@@ -62,7 +62,7 @@
 // it is.
 // TODO(golang/go#49733) Use the CandidateModulePaths function from pkgsite to catch
 // longer module paths, e.g. github.com/pulumi/pulumi/sdk/v2.
-func cveModulePath(ctx context.Context, c *cveschema.CVE) (_ string, err error) {
+func cveModulePath(ctx context.Context, c *cveschema.CVE, pkgsiteURL string) (_ string, err error) {
 	defer derrors.Wrap(&err, "cveModulePath(%q)", c.ID)
 	for _, r := range c.References.Data {
 		if r.URL == "" {
@@ -87,7 +87,7 @@
 				continue
 			}
 			mod := strings.Join(parts[0:3], "/")
-			known, err := knownToPkgsite(ctx, "https://pkg.go.dev", mod)
+			known, err := knownToPkgsite(ctx, pkgsiteURL, mod)
 			if err != nil {
 				return "", err
 			}
diff --git a/internal/worker/update_test.go b/internal/worker/update_test.go
index fa490c9..a0d844e 100644
--- a/internal/worker/update_test.go
+++ b/internal/worker/update_test.go
@@ -9,6 +9,7 @@
 	"encoding/json"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/go-git/go-git/v5"
 	"github.com/go-git/go-git/v5/plumbing"
@@ -20,7 +21,7 @@
 )
 
 func TestRepoCVEFiles(t *testing.T) {
-	repo, err := readTxtarRepo("testdata/basic.txtar")
+	repo, err := readTxtarRepo("testdata/basic.txtar", time.Now())
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -46,7 +47,7 @@
 
 func TestDoUpdate(t *testing.T) {
 	ctx := log.WithLineLogger(context.Background())
-	repo, err := readTxtarRepo("testdata/basic.txtar")
+	repo, err := readTxtarRepo("testdata/basic.txtar", time.Now())
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
new file mode 100644
index 0000000..0c53eb9
--- /dev/null
+++ b/internal/worker/worker.go
@@ -0,0 +1,88 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package worker
+
+// This file has the public API of the worker, used by cmd/worker as well
+// as the server in this package.
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/go-git/go-git/v5"
+	"github.com/go-git/go-git/v5/plumbing"
+	"golang.org/x/vuln/internal/cveschema"
+	"golang.org/x/vuln/internal/derrors"
+	"golang.org/x/vuln/internal/gitrepo"
+	"golang.org/x/vuln/internal/worker/store"
+)
+
+// UpdateCommit performs an update on the store using the given commit.
+// Unless force is true, it checks that the update makes sense before doing it.
+func UpdateCommit(ctx context.Context, repoPath, commitHash string, st store.Store, pkgsiteURL string, force bool) (err error) {
+	defer derrors.Wrap(&err, "RunCommitUpdate(%q, %q, force=%t)", repoPath, commitHash, force)
+
+	repo, err := gitrepo.CloneOrOpen(repoPath)
+	if err != nil {
+		return err
+	}
+	ch := plumbing.NewHash(commitHash)
+	if !force {
+		if err := checkUpdate(ctx, repo, ch, st); err != nil {
+			return err
+		}
+	}
+	return doUpdate(ctx, repo, ch, st, func(cve *cveschema.CVE) (bool, error) {
+		return TriageCVE(ctx, cve, pkgsiteURL)
+	})
+}
+
+// checkUpdate performs sanity checks on a potential update.
+// It verifies that there is not an update currently in progress,
+// and it makes sure that the update is to a more recent commit.
+func checkUpdate(ctx context.Context, repo *git.Repository, commitHash plumbing.Hash, st store.Store) error {
+	urs, err := st.ListCommitUpdateRecords(ctx, 1)
+	if err != nil {
+		return err
+	}
+	if len(urs) == 0 {
+		// No updates, we're good.
+		return nil
+	}
+	lu := urs[0]
+	if lu.EndedAt.IsZero() {
+		return &CheckUpdateError{
+			msg: fmt.Sprintf("latest update started %s ago and has not finished", time.Since(lu.StartedAt)),
+		}
+	}
+	if lu.Error != "" {
+		return &CheckUpdateError{
+			msg: fmt.Sprintf("latest update finished with error %q", lu.Error),
+		}
+	}
+	commit, err := repo.CommitObject(commitHash)
+	if err != nil {
+		return err
+	}
+	if commit.Committer.When.Before(lu.CommitTime) {
+		return &CheckUpdateError{
+			msg: fmt.Sprintf("commit %s time %s is before latest update commit %s time %s",
+				commitHash, commit.Committer.When.Format(time.RFC3339),
+				lu.CommitHash, lu.CommitTime.Format(time.RFC3339)),
+		}
+	}
+	return nil
+}
+
+// CheckUpdateError is an error returned from UpdateCommit that can be avoided
+// calling UpdateCommit with force set to true.
+type CheckUpdateError struct {
+	msg string
+}
+
+func (c *CheckUpdateError) Error() string {
+	return c.msg
+}
diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go
new file mode 100644
index 0000000..51768d5
--- /dev/null
+++ b/internal/worker/worker_test.go
@@ -0,0 +1,79 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package worker
+
+import (
+	"context"
+	"strings"
+	"testing"
+	"time"
+
+	"golang.org/x/vuln/internal/worker/store"
+)
+
+func TestCheckUpdate(t *testing.T) {
+	ctx := context.Background()
+	tm := time.Date(2021, 1, 26, 0, 0, 0, 0, time.Local)
+	repo, err := readTxtarRepo("testdata/basic.txtar", tm)
+	if err != nil {
+		t.Fatal(err)
+	}
+	for _, test := range []struct {
+		latestUpdate *store.CommitUpdateRecord
+		want         string // non-empty => substring of error message
+	}{
+		// no latest update, no problem
+		{nil, ""},
+		// latest update finished and commit is earlier; no problem
+		{
+			&store.CommitUpdateRecord{
+				EndedAt:    time.Now(),
+				CommitHash: "abc",
+				CommitTime: tm.Add(-time.Hour),
+			},
+			"",
+		},
+		// latest update didn't finish
+		{
+			&store.CommitUpdateRecord{
+				CommitHash: "abc",
+				CommitTime: tm.Add(-time.Hour),
+			},
+			"not finish",
+		},
+		// latest update finished with error
+		{
+			&store.CommitUpdateRecord{
+				CommitHash: "abc",
+				CommitTime: tm.Add(-time.Hour),
+				EndedAt:    time.Now(),
+				Error:      "bad",
+			},
+			"with error",
+		},
+		// latest update finished on a later commit
+		{
+			&store.CommitUpdateRecord{
+				EndedAt:    time.Now(),
+				CommitHash: "abc",
+				CommitTime: tm.Add(time.Hour),
+			},
+			"before",
+		},
+	} {
+		mstore := store.NewMemStore()
+		if test.latestUpdate != nil {
+			if err := mstore.CreateCommitUpdateRecord(ctx, test.latestUpdate); err != nil {
+				t.Fatal(err)
+			}
+		}
+		got := checkUpdate(ctx, repo, headCommit(t, repo).Hash, mstore)
+		if got == nil && test.want != "" {
+			t.Errorf("%+v:\ngot no error, wanted %q", test.latestUpdate, test.want)
+		} else if got != nil && !strings.Contains(got.Error(), test.want) {
+			t.Errorf("%+v:\ngot '%s', does not contain %q", test.latestUpdate, got, test.want)
+		}
+	}
+}