internal/worker/store: simplify tests
Using generics, write a function "must1" that reduces clutter in tests
by handling errors in setup code.
Change-Id: Iedb6cadac11359254d3bdcf8495c4bf00ae9431b
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/393776
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/worker/store/store_test.go b/internal/worker/store/store_test.go
index 3344287..b5becb5 100644
--- a/internal/worker/store/store_test.go
+++ b/internal/worker/store/store_test.go
@@ -19,10 +19,22 @@
"golang.org/x/vulndb/internal/ghsa"
)
-func must(t *testing.T, err error) {
- t.Helper()
- if err != nil {
- t.Fatal(err)
+func must(err error) func(*testing.T) {
+ return func(t *testing.T) {
+ t.Helper()
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func must1[T any](x T, err error) func(*testing.T) T {
+ return func(t *testing.T) T {
+ t.Helper()
+ if err != nil {
+ t.Fatal(err)
+ }
+ return x
}
}
@@ -53,24 +65,21 @@
CommitHash: "abc",
NumTotal: 100,
}
- must(t, s.CreateCommitUpdateRecord(ctx, u1))
+ must(s.CreateCommitUpdateRecord(ctx, u1))(t)
u1.EndedAt = u1.StartedAt.Add(10 * time.Minute)
u1.NumAdded = 100
- must(t, s.SetCommitUpdateRecord(ctx, u1))
+ must(s.SetCommitUpdateRecord(ctx, u1))(t)
u2 := &CommitUpdateRecord{
StartedAt: start.Add(time.Hour),
CommitHash: "def",
NumTotal: 80,
}
- must(t, s.CreateCommitUpdateRecord(ctx, u2))
+ must(s.CreateCommitUpdateRecord(ctx, u2))(t)
u2.EndedAt = u2.StartedAt.Add(8 * time.Minute)
u2.NumAdded = 40
u2.NumModified = 40
- must(t, s.SetCommitUpdateRecord(ctx, u2))
- got, err := s.ListCommitUpdateRecords(ctx, 0)
- if err != nil {
- t.Fatal(err)
- }
+ must(s.SetCommitUpdateRecord(ctx, u2))(t)
+ got := must1(s.ListCommitUpdateRecords(ctx, 0))(t)
want := []*CommitUpdateRecord{u2, u1}
diff(t, want, got, cmpopts.IgnoreFields(CommitUpdateRecord{}, "UpdatedAt"))
for _, g := range got {
@@ -143,24 +152,14 @@
// Test SetCVERecord.
set := func(r *CVERecord) *CVERecord {
- err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+ must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
return tx.SetCVERecord(r)
- })
- if err != nil {
- t.Fatal(err)
- }
- r2, err := s.GetCVERecord(ctx, r.ID)
- if err != nil {
- t.Fatal(err)
- }
- return r2
+ }))(t)
+ return must1(s.GetCVERecord(ctx, r.ID))(t)
}
// Make sure the first record is the same that we created.
- got, err := s.GetCVERecord(ctx, id1)
- if err != nil {
- t.Fatal(err)
- }
+ got := must1(s.GetCVERecord(ctx, id1))(t)
diff(t, crs[0], got)
// Change the state and the commit hash.
@@ -172,31 +171,20 @@
want.CommitHash = "999"
diff(t, &want, got)
- gotNoAction, err := s.ListCVERecordsWithTriageState(ctx, TriageStateNoActionNeeded)
- if err != nil {
- t.Fatal(err)
- }
+ gotNoAction := must1(s.ListCVERecordsWithTriageState(ctx, TriageStateNoActionNeeded))(t)
diff(t, crs[1:], gotNoAction)
}
func testDirHashes(t *testing.T, s Store) {
ctx := context.Background()
const dir = "a/b/c"
- got, err := s.GetDirectoryHash(ctx, dir)
- if err != nil {
- t.Fatal(err)
- }
+ got := must1(s.GetDirectoryHash(ctx, dir))(t)
if got != "" {
t.Fatalf("got %q, want empty", got)
}
const want = "123"
- if err := s.SetDirectoryHash(ctx, "a/b/c", want); err != nil {
- t.Fatal(err)
- }
- got, err = s.GetDirectoryHash(ctx, dir)
- if err != nil {
- t.Fatal(err)
- }
+ must(s.SetDirectoryHash(ctx, "a/b/c", want))(t)
+ got = must1(s.GetDirectoryHash(ctx, dir))(t)
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
@@ -215,35 +203,26 @@
TriageState: TriageStateNeedsIssue,
},
}
- err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+ must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
for _, g := range gs {
if err := tx.CreateGHSARecord(g); err != nil {
return err
}
}
return nil
- })
- if err != nil {
- t.Fatal(err)
- }
+ }))(t)
// Modify one of them.
gs[1].TriageState = TriageStateIssueCreated
- err = s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+ must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
return tx.SetGHSARecord(gs[1])
- })
- if err != nil {
- t.Fatal(err)
- }
+ }))(t)
// Retrieve and compare.
var got []*GHSARecord
- err = s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+ must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
var err error
got, err = tx.GetGHSARecords()
return err
- })
- if err != nil {
- t.Fatal(err)
- }
+ }))(t)
if len(got) != len(gs) {
t.Fatalf("got %d records, want %d", len(got), len(gs))
}
@@ -254,14 +233,11 @@
// Retrieve one record by ID.
var got0 *GHSARecord
- err = s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+ must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
var err error
got0, err = tx.GetGHSARecord(gs[0].GetID())
return err
- })
- if err != nil {
- t.Fatal(err)
- }
+ }))(t)
if got, want := got0, gs[0]; !cmp.Equal(got, want) {
t.Errorf("got %+v, want %+v", got, want)
}
@@ -291,28 +267,23 @@
},
}
for _, r := range rs {
- if err := s.CreateModuleScanRecord(ctx, r); err != nil {
- t.Fatal(err)
- }
+ must(s.CreateModuleScanRecord(ctx, r))(t)
}
// GetModuleScanRecord
- got, err := s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm)
- if err != nil {
- t.Fatal(err)
- }
+ got := must1(s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm))(t)
// Expect the most recent.
if want := rs[1]; !cmp.Equal(got, want) {
t.Errorf("got\n%+v\nwant\n%+v", got, want)
}
// Non-existent record.
- got, err = s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm.Add(time.Second))
+ got, err := s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm.Add(time.Second))
if got != nil || err != nil {
t.Errorf("got (%v, %v), want (nil, nil)", got, err)
}
// ListModuleScanRecords
- got2, err := s.ListModuleScanRecords(ctx, 0)
+ got2 := must1(s.ListModuleScanRecords(ctx, 0))(t)
if err != nil {
t.Fatal(err)
}
@@ -324,17 +295,14 @@
}
func createCVERecords(t *testing.T, ctx context.Context, s Store, crs []*CVERecord) {
- err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+ must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
for _, cr := range crs {
if err := tx.CreateCVERecord(cr); err != nil {
return err
}
}
return nil
- })
- if err != nil {
- t.Fatal(err)
- }
+ }))(t)
}
func diff(t *testing.T, want, got interface{}, opts ...cmp.Option) {