blob: 5e7f517a730c51ead62e0cd946100b1209e0ea70 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package worker
import (
"context"
"encoding/json"
"strings"
"testing"
"time"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/vuln/internal/cveschema"
"golang.org/x/vuln/internal/worker/log"
"golang.org/x/vuln/internal/worker/store"
)
// TestRepoCVEFiles checks that repoCVEFiles lists the CVE JSON files in the
// head commit of the basic test repo, with the directory, file name, year and
// number of each. Tree and blob hashes are not compared.
func TestRepoCVEFiles(t *testing.T) {
	repo, err := readTxtarRepo("testdata/basic.txtar", time.Now())
	if err != nil {
		t.Fatal(err)
	}
	// headCommit returns only the commit (it is given t), so there is no
	// error to check after it.
	commit := headCommit(t, repo)
	got, err := repoCVEFiles(repo, commit)
	if err != nil {
		t.Fatal(err)
	}
	want := []repoFile{
		{dirPath: "2021/0xxx", filename: "CVE-2021-0001.json", year: 2021, number: 1},
		{dirPath: "2021/0xxx", filename: "CVE-2021-0010.json", year: 2021, number: 10},
		{dirPath: "2021/1xxx", filename: "CVE-2021-1384.json", year: 2021, number: 1384},
	}
	// Hashes depend on the fixture's contents, so they are left out of the
	// comparison.
	opt := cmpopts.IgnoreFields(repoFile{}, "treeHash", "blobHash")
	if diff := cmp.Diff(want, got, cmp.AllowUnexported(repoFile{}), opt); diff != "" {
		t.Errorf("mismatch (-want, +got):\n%s", diff)
	}
}
// TestDoUpdate runs doUpdate over the head commit of the basic test repo,
// starting from several initial store states, and checks the CVE records the
// store holds afterwards.
func TestDoUpdate(t *testing.T) {
	ctx := log.WithLineLogger(context.Background())
	repo, err := readTxtarRepo("testdata/basic.txtar", time.Now())
	if err != nil {
		t.Fatal(err)
	}
	h, err := headHash(repo)
	if err != nil {
		t.Fatal(err)
	}
	// Only the CVE whose ID ends in "0001" is reported as needing an issue.
	needsIssue := func(cve *cveschema.CVE) (bool, error) {
		return strings.HasSuffix(cve.ID, "0001"), nil
	}
	ref, err := repo.Reference(plumbing.HEAD, true)
	if err != nil {
		t.Fatal(err)
	}
	commitHash := ref.Hash().String()
	const (
		path1 = "2021/0xxx/CVE-2021-0001.json"
		path2 = "2021/0xxx/CVE-2021-0010.json"
		path3 = "2021/1xxx/CVE-2021-1384.json"
	)
	cve1, bh1 := readCVE(t, repo, path1)
	cve2, bh2 := readCVE(t, repo, path2)
	cve3, bh3 := readCVE(t, repo, path3)
	// CVERecords after the above CVEs are added to an empty DB.
	rs := []*store.CVERecord{
		{
			ID:          cve1.ID,
			CVEState:    cve1.State,
			Path:        path1,
			BlobHash:    bh1,
			CommitHash:  commitHash,
			TriageState: store.TriageStateNeedsIssue, // a public CVE, needsIssue returns true
		},
		{
			ID:          cve2.ID,
			CVEState:    cve2.State,
			Path:        path2,
			BlobHash:    bh2,
			CommitHash:  commitHash,
			TriageState: store.TriageStateNoActionNeeded, // state is reserved
		},
		{
			ID:          cve3.ID,
			CVEState:    cve3.State,
			Path:        path3,
			BlobHash:    bh3,
			CommitHash:  commitHash,
			TriageState: store.TriageStateNoActionNeeded, // state is rejected
		},
	}
	// withTriageState returns a copy of r with TriageState set to ts. It also
	// changes BlobHash and CommitHash so the copy differs from the repo. If
	// the blob hash were left unchanged, doUpdate would make no update for
	// the record.
	withTriageState := func(r *store.CVERecord, ts store.TriageState) *store.CVERecord {
		c := *r
		c.BlobHash += "x" // if we don't use a different blob hash, no update will happen
		c.CommitHash = "?"
		c.TriageState = ts
		return &c
	}
	for _, test := range []struct {
		name string
		cur  []*store.CVERecord // current state of DB
		want []*store.CVERecord // expected state after update
	}{
		{
			name: "empty",
			cur:  nil,
			want: rs,
		},
		{
			name: "no change",
			cur:  rs,
			want: rs,
		},
		{
			name: "pre-issue changes",
			cur: []*store.CVERecord{
				// NoActionNeeded -> NeedsIssue
				withTriageState(rs[0], store.TriageStateNoActionNeeded),
				// NeedsIssue -> NoActionNeeded
				withTriageState(rs[1], store.TriageStateNeedsIssue),
				// NoActionNeeded, triage state stays the same but other fields change.
				withTriageState(rs[2], store.TriageStateNoActionNeeded),
			},
			want: rs,
		},
		{
			name: "post-issue changes",
			cur: []*store.CVERecord{
				// IssueCreated -> UpdatedSinceIssueCreation
				withTriageState(rs[0], store.TriageStateIssueCreated),
				// UpdatedSinceIssueCreation stays UpdatedSinceIssueCreation,
				// and a reason is recorded.
				withTriageState(rs[1], store.TriageStateUpdatedSinceIssueCreation),
				// rs[2] is not in the DB, so it is expected to be added as-is.
			},
			want: []*store.CVERecord{
				func() *store.CVERecord {
					c := *rs[0]
					c.TriageState = store.TriageStateUpdatedSinceIssueCreation
					c.TriageStateReason = "CVE changed; needs issue = true"
					return &c
				}(),
				func() *store.CVERecord {
					c := *rs[1]
					c.TriageState = store.TriageStateUpdatedSinceIssueCreation
					c.TriageStateReason = "CVE changed; needs issue = false"
					return &c
				}(),
				rs[2],
			},
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			mstore := store.NewMemStore()
			createCVERecords(t, mstore, test.cur)
			if _, err := doUpdate(ctx, repo, h, mstore, needsIssue); err != nil {
				t.Fatal(err)
			}
			got := mstore.CVERecords()
			// Index the expected records by ID to match the store's map form.
			want := make(map[string]*store.CVERecord, len(test.want))
			for _, cr := range test.want {
				want[cr.ID] = cr
			}
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("mismatch (-want, +got):\n%s", diff)
			}
		})
	}
}
// TestGroupFilesByDirectory checks that groupFilesByDirectory splits files into
// consecutive runs sharing a dirPath, preserving input order. It also checks
// that input in which a directory reappears after a different one yields an
// error.
func TestGroupFilesByDirectory(t *testing.T) {
	for _, test := range []struct {
		in   []repoFile
		want [][]repoFile
	}{
		{in: nil, want: nil},
		{
			in:   []repoFile{{dirPath: "a"}},
			want: [][]repoFile{{{dirPath: "a"}}},
		},
		{
			in: []repoFile{
				{dirPath: "a", filename: "f1"},
				{dirPath: "a", filename: "f2"},
			},
			want: [][]repoFile{{
				{dirPath: "a", filename: "f1"},
				{dirPath: "a", filename: "f2"},
			}},
		},
		{
			in: []repoFile{
				{dirPath: "a", filename: "f1"},
				{dirPath: "a", filename: "f2"},
				{dirPath: "b", filename: "f1"},
				{dirPath: "c", filename: "f1"},
				{dirPath: "c", filename: "f2"},
			},
			want: [][]repoFile{
				{
					{dirPath: "a", filename: "f1"},
					{dirPath: "a", filename: "f2"},
				},
				{
					{dirPath: "b", filename: "f1"},
				},
				{
					{dirPath: "c", filename: "f1"},
					{dirPath: "c", filename: "f2"},
				},
			},
		},
	} {
		got, err := groupFilesByDirectory(test.in)
		if err != nil {
			t.Fatalf("%v: %v", test.in, err)
		}
		// cmp.Diff marks entries of its first argument with "-" and of its
		// second with "+", so want must come first to match the message.
		if diff := cmp.Diff(test.want, got, cmp.AllowUnexported(repoFile{})); diff != "" {
			t.Errorf("%v: (-want, +got)\n%s", test.in, diff)
		}
	}
	// Directory "a" reappears after "b", which must be reported as an error.
	_, err := groupFilesByDirectory([]repoFile{{dirPath: "a"}, {dirPath: "b"}, {dirPath: "a"}})
	if err == nil {
		t.Error("got nil, want error")
	}
}
// readCVE decodes the CVE JSON file at path in the head commit of repo. It
// returns the decoded CVE and the string form of the file's hash. Any failure
// is reported through t.Fatal.
func readCVE(t *testing.T, repo *git.Repository, path string) (*cveschema.CVE, string) {
	t.Helper()
	c := headCommit(t, repo)
	file, err := c.File(path)
	if err != nil {
		t.Fatal(err)
	}
	r, err := file.Reader()
	if err != nil {
		t.Fatal(err)
	}
	// The reader is an io.ReadCloser; close it so the object is released.
	defer r.Close()
	var cve cveschema.CVE
	if err := json.NewDecoder(r).Decode(&cve); err != nil {
		t.Fatal(err)
	}
	return &cve, file.Hash.String()
}
// createCVERecords inserts every record in crs into s within a single
// transaction. If any insert or the transaction fails, it fails the test via
// t.Fatal.
func createCVERecords(t *testing.T, s store.Store, crs []*store.CVERecord) {
	t.Helper()
	err := s.RunTransaction(context.Background(), func(_ context.Context, tx store.Transaction) error {
		for _, cr := range crs {
			if err := tx.CreateCVERecord(cr); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
}