internal/worker/store: simplify tests

Using generics, write a function "must1" that reduces clutter in tests
by handling errors in setup code.

Change-Id: Iedb6cadac11359254d3bdcf8495c4bf00ae9431b
Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/393776
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/worker/store/store_test.go b/internal/worker/store/store_test.go
index 3344287..b5becb5 100644
--- a/internal/worker/store/store_test.go
+++ b/internal/worker/store/store_test.go
@@ -19,10 +19,22 @@
 	"golang.org/x/vulndb/internal/ghsa"
 )
 
-func must(t *testing.T, err error) {
-	t.Helper()
-	if err != nil {
-		t.Fatal(err)
+func must(err error) func(*testing.T) {
+	return func(t *testing.T) {
+		t.Helper()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func must1[T any](x T, err error) func(*testing.T) T {
+	return func(t *testing.T) T {
+		t.Helper()
+		if err != nil {
+			t.Fatal(err)
+		}
+		return x
 	}
 }
 
@@ -53,24 +65,21 @@
 		CommitHash: "abc",
 		NumTotal:   100,
 	}
-	must(t, s.CreateCommitUpdateRecord(ctx, u1))
+	must(s.CreateCommitUpdateRecord(ctx, u1))(t)
 	u1.EndedAt = u1.StartedAt.Add(10 * time.Minute)
 	u1.NumAdded = 100
-	must(t, s.SetCommitUpdateRecord(ctx, u1))
+	must(s.SetCommitUpdateRecord(ctx, u1))(t)
 	u2 := &CommitUpdateRecord{
 		StartedAt:  start.Add(time.Hour),
 		CommitHash: "def",
 		NumTotal:   80,
 	}
-	must(t, s.CreateCommitUpdateRecord(ctx, u2))
+	must(s.CreateCommitUpdateRecord(ctx, u2))(t)
 	u2.EndedAt = u2.StartedAt.Add(8 * time.Minute)
 	u2.NumAdded = 40
 	u2.NumModified = 40
-	must(t, s.SetCommitUpdateRecord(ctx, u2))
-	got, err := s.ListCommitUpdateRecords(ctx, 0)
-	if err != nil {
-		t.Fatal(err)
-	}
+	must(s.SetCommitUpdateRecord(ctx, u2))(t)
+	got := must1(s.ListCommitUpdateRecords(ctx, 0))(t)
 	want := []*CommitUpdateRecord{u2, u1}
 	diff(t, want, got, cmpopts.IgnoreFields(CommitUpdateRecord{}, "UpdatedAt"))
 	for _, g := range got {
@@ -143,24 +152,14 @@
 	// Test SetCVERecord.
 
 	set := func(r *CVERecord) *CVERecord {
-		err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+		must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 			return tx.SetCVERecord(r)
-		})
-		if err != nil {
-			t.Fatal(err)
-		}
-		r2, err := s.GetCVERecord(ctx, r.ID)
-		if err != nil {
-			t.Fatal(err)
-		}
-		return r2
+		}))(t)
+		return must1(s.GetCVERecord(ctx, r.ID))(t)
 	}
 
 	// Make sure the first record is the same that we created.
-	got, err := s.GetCVERecord(ctx, id1)
-	if err != nil {
-		t.Fatal(err)
-	}
+	got := must1(s.GetCVERecord(ctx, id1))(t)
 	diff(t, crs[0], got)
 
 	// Change the state and the commit hash.
@@ -172,31 +171,20 @@
 	want.CommitHash = "999"
 	diff(t, &want, got)
 
-	gotNoAction, err := s.ListCVERecordsWithTriageState(ctx, TriageStateNoActionNeeded)
-	if err != nil {
-		t.Fatal(err)
-	}
+	gotNoAction := must1(s.ListCVERecordsWithTriageState(ctx, TriageStateNoActionNeeded))(t)
 	diff(t, crs[1:], gotNoAction)
 }
 
 func testDirHashes(t *testing.T, s Store) {
 	ctx := context.Background()
 	const dir = "a/b/c"
-	got, err := s.GetDirectoryHash(ctx, dir)
-	if err != nil {
-		t.Fatal(err)
-	}
+	got := must1(s.GetDirectoryHash(ctx, dir))(t)
 	if got != "" {
 		t.Fatalf("got %q, want empty", got)
 	}
 	const want = "123"
-	if err := s.SetDirectoryHash(ctx, "a/b/c", want); err != nil {
-		t.Fatal(err)
-	}
-	got, err = s.GetDirectoryHash(ctx, dir)
-	if err != nil {
-		t.Fatal(err)
-	}
+	must(s.SetDirectoryHash(ctx, "a/b/c", want))(t)
+	got = must1(s.GetDirectoryHash(ctx, dir))(t)
 	if got != want {
 		t.Fatalf("got %q, want %q", got, want)
 	}
@@ -215,35 +203,26 @@
 			TriageState: TriageStateNeedsIssue,
 		},
 	}
-	err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+	must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 		for _, g := range gs {
 			if err := tx.CreateGHSARecord(g); err != nil {
 				return err
 			}
 		}
 		return nil
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
+	}))(t)
 	// Modify one of them.
 	gs[1].TriageState = TriageStateIssueCreated
-	err = s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+	must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 		return tx.SetGHSARecord(gs[1])
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
+	}))(t)
 	// Retrieve and compare.
 	var got []*GHSARecord
-	err = s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+	must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 		var err error
 		got, err = tx.GetGHSARecords()
 		return err
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
+	}))(t)
 	if len(got) != len(gs) {
 		t.Fatalf("got %d records, want %d", len(got), len(gs))
 	}
@@ -254,14 +233,11 @@
 
 	// Retrieve one record by ID.
 	var got0 *GHSARecord
-	err = s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+	must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 		var err error
 		got0, err = tx.GetGHSARecord(gs[0].GetID())
 		return err
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
+	}))(t)
 	if got, want := got0, gs[0]; !cmp.Equal(got, want) {
 		t.Errorf("got %+v, want %+v", got, want)
 	}
@@ -291,28 +267,23 @@
 		},
 	}
 	for _, r := range rs {
-		if err := s.CreateModuleScanRecord(ctx, r); err != nil {
-			t.Fatal(err)
-		}
+		must(s.CreateModuleScanRecord(ctx, r))(t)
 	}
 
 	// GetModuleScanRecord
-	got, err := s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm)
-	if err != nil {
-		t.Fatal(err)
-	}
+	got := must1(s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm))(t)
 	// Expect the most recent.
 	if want := rs[1]; !cmp.Equal(got, want) {
 		t.Errorf("got\n%+v\nwant\n%+v", got, want)
 	}
 	// Non-existent record.
-	got, err = s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm.Add(time.Second))
+	got, err := s.GetModuleScanRecord(ctx, "m1", "v1.2.3", tm.Add(time.Second))
 	if got != nil || err != nil {
 		t.Errorf("got (%v, %v), want (nil, nil)", got, err)
 	}
 
 	// ListModuleScanRecords
-	got2, err := s.ListModuleScanRecords(ctx, 0)
+	got2 := must1(s.ListModuleScanRecords(ctx, 0))(t)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -324,17 +295,14 @@
 }
 
 func createCVERecords(t *testing.T, ctx context.Context, s Store, crs []*CVERecord) {
-	err := s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
+	must(s.RunTransaction(ctx, func(ctx context.Context, tx Transaction) error {
 		for _, cr := range crs {
 			if err := tx.CreateCVERecord(cr); err != nil {
 				return err
 			}
 		}
 		return nil
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
+	}))(t)
 }
 
 func diff(t *testing.T, want, got interface{}, opts ...cmp.Option) {