internal/worker: refactor fetch_test.go

Tests for FetchAndUpdateState are refactored, since there was a lot of
duplicated logic across these test. These tests now use the
checkPackageVersionStates and fetchAndCheckStatus helper functions to
call FetchAndUpdateState and check the corresponding output.

Change-Id: I431b0cb3862b5b76b6c3cf85cb07fa9a9f11d941
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/247192
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/worker/fetch_test.go b/internal/worker/fetch_test.go
index bebbced..691faa1 100644
--- a/internal/worker/fetch_test.go
+++ b/internal/worker/fetch_test.go
@@ -18,6 +18,7 @@
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"github.com/google/safehtml/testconversions"
+	"golang.org/x/mod/semver"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/fetch"
@@ -57,10 +58,9 @@
 func TestFetchAndUpdateState_NotFound(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
-
 	defer postgres.ResetTestDB(testDB, t)
 
-	proxyClient, teardownProxy := proxy.SetupTestProxy(t, []*proxy.Module{
+	proxyClient, teardown := proxy.SetupTestProxy(t, []*proxy.Module{
 		{
 			ModulePath: sample.ModulePath,
 			Version:    sample.VersionString,
@@ -71,57 +71,8 @@
 			},
 		},
 	})
-	sourceClient := source.NewClient(sourceTimeout)
-
-	checkStatus := func(want int) {
-		t.Helper()
-		vs, err := testDB.GetModuleVersionState(ctx, sample.ModulePath, sample.VersionString)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if vs.Status != want {
-			t.Fatalf("testDB.GetModuleVersionState(ctx, %q, %q): status = %v, want = %d", sample.ModulePath, sample.VersionString, vs.Status, want)
-		}
-		if want != http.StatusNotFound {
-			vm, err := testDB.GetVersionMap(ctx, sample.ModulePath, sample.VersionString)
-			if err != nil {
-				t.Fatal(err)
-			}
-			if vm.Status != want {
-				t.Fatalf("testDB.GetVersionMap(ctx, %q, %q): status = %d, want = %d", sample.ModulePath, sample.VersionString, vm.Status, want)
-			}
-		}
-	}
-
-	// Fetch a module@version that the proxy serves successfully.
-	if _, err := FetchAndUpdateState(ctx, sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, testAppVersion); err != nil {
-		t.Fatal(err)
-	}
-
-	// Verify that the module status is recorded correctly, and that the version is in the DB.
-	checkStatus(http.StatusOK)
-
-	if _, err := testDB.GetModuleInfo(ctx, sample.ModulePath, sample.VersionString); err != nil {
-		t.Fatal(err)
-	}
-
-	gotStates, err := testDB.GetPackageVersionStatesForModule(ctx, sample.ModulePath, sample.VersionString)
-	if err != nil {
-		t.Fatal(err)
-	}
-	wantStates := []*internal.PackageVersionState{
-		{
-			PackagePath: sample.ModulePath + "/foo",
-			ModulePath:  sample.ModulePath,
-			Version:     sample.VersionString,
-			Status:      http.StatusOK,
-		},
-	}
-	if diff := cmp.Diff(wantStates, gotStates); diff != "" {
-		t.Errorf("testDB.GetPackageVersionStatesForModule(ctx, %q, %q) mismatch (-want +got):\n%s", sample.ModulePath, sample.VersionString, diff)
-	}
-
-	teardownProxy()
+	defer teardown()
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, sample.VersionString, http.StatusOK)
 
 	// Take down the module, by having the proxy serve a 404/410 for it.
 	proxyServer := proxy.NewServer([]*proxy.Module{}) // serve no versions, not even the defaults.
@@ -131,21 +82,16 @@
 	proxyClient, teardownProxy2 := proxy.TestProxyServer(t, proxyServer)
 	defer teardownProxy2()
 
-	// Now fetch it again.
-	if code, _ := FetchAndUpdateState(ctx, sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, testAppVersion); code != http.StatusNotFound {
-		t.Fatalf("FetchAndUpdateState(ctx, %q, %q, proxyClient, sourceClient, testDB): got code %d, want 404/410", sample.ModulePath, sample.VersionString, code)
-	}
-
-	// The new state should have a status of Not Found.
-	checkStatus(http.StatusNotFound)
-
-	gotStates, err = testDB.GetPackageVersionStatesForModule(ctx, sample.ModulePath, sample.VersionString)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if diff := cmp.Diff(wantStates, gotStates); diff != "" {
-		t.Errorf("testDB.GetPackageVersionStatesForModule(ctx, %q, %q) mismatch (-want +got):\n%s", sample.ModulePath, sample.VersionString, diff)
-	}
+	// Now fetch it again. The new state should have a status of Not Found.
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, sample.VersionString, http.StatusNotFound)
+	checkPackageVersionStates(ctx, t, sample.ModulePath, sample.VersionString, []*internal.PackageVersionState{
+		{
+			PackagePath: sample.ModulePath + "/foo",
+			ModulePath:  sample.ModulePath,
+			Version:     sample.VersionString,
+			Status:      http.StatusOK,
+		},
+	})
 
 	// The module should no longer be in the database:
 	// - It shouldn't be in the modules table. That also covers licenses, packages and paths tables
@@ -154,7 +100,6 @@
 	if _, err := testDB.GetModuleInfo(ctx, sample.ModulePath, sample.VersionString); !errors.Is(err, derrors.NotFound) {
 		t.Fatalf("GetModuleInfo: got %v, want NotFound", err)
 	}
-
 	checkNotInTable := func(table, column string) {
 		q := fmt.Sprintf("SELECT 1 FROM %s WHERE %s = $1 LIMIT 1", table, column)
 		var x int
@@ -163,7 +108,6 @@
 			t.Errorf("table %s: got %v, want ErrNoRows", table, err)
 		}
 	}
-
 	checkNotInTable("search_documents", "module_path")
 	checkNotInTable("imports_unique", "from_module_path")
 	checkNotInTable("imports", "from_module_path")
@@ -178,39 +122,12 @@
 
 	proxyClient, teardownProxy := proxy.SetupTestProxy(t, nil)
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
 
 	if err := testDB.InsertExcludedPrefix(ctx, sample.ModulePath, "user", "for testing"); err != nil {
 		t.Fatal(err)
 	}
 
-	checkModuleNotFound(t, ctx, sample.ModulePath, sample.VersionString, proxyClient, sourceClient, http.StatusForbidden, derrors.Excluded)
-}
-
-func checkModuleNotFound(t *testing.T, ctx context.Context, modulePath, version string, proxyClient *proxy.Client, sourceClient *source.Client, wantCode int, wantErr error) {
-	t.Helper()
-	code, err := FetchAndUpdateState(ctx, modulePath, version, proxyClient, sourceClient, testDB, "appVersionLabel")
-	if code != wantCode || !errors.Is(err, wantErr) {
-		t.Fatalf("got %d, %v; want %d, Is(err, %v)", code, err, wantCode, wantErr)
-	}
-	_, err = testDB.GetModuleInfo(ctx, modulePath, version)
-	if !errors.Is(err, derrors.NotFound) {
-		t.Fatalf("got %v, want Is(NotFound)", err)
-	}
-	vs, err := testDB.GetModuleVersionState(ctx, modulePath, version)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if vs.Status != wantCode {
-		t.Fatalf("testDB.GetModuleVersionState(ctx, %q, %q): status=%v, want %d", modulePath, version, vs.Status, wantCode)
-	}
-	vm, err := testDB.GetVersionMap(ctx, modulePath, version)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if vm.Status != wantCode {
-		t.Fatalf("testDB.GetVersionMap(ctx, %q, %q): status=%d; want %d", modulePath, version, vm.Status, wantCode)
-	}
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, sample.VersionString, http.StatusForbidden)
 }
 
 func TestFetchAndUpdateState_BadRequestedVersion(t *testing.T) {
@@ -224,24 +141,7 @@
 	)
 	proxyClient, teardownProxy := proxy.SetupTestProxy(t, []*proxy.Module{buildConstraintsMod})
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
-
-	want := http.StatusNotFound
-	code, _ := FetchAndUpdateState(ctx, modulePath, version, proxyClient, sourceClient, testDB, testAppVersion)
-	if code != want {
-		t.Fatalf("got code %d, want %d", code, want)
-	}
-	_, err := testDB.GetModuleVersionState(ctx, modulePath, version)
-	if !errors.Is(err, derrors.NotFound) {
-		t.Fatal(err)
-	}
-	vm, err := testDB.GetVersionMap(ctx, modulePath, version)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if vm.Status != http.StatusNotFound {
-		t.Fatalf("testDB.GetVersionMap(ctx, %q, %q): status=%v, want %d", modulePath, version, code, want)
-	}
+	fetchAndCheckStatus(ctx, t, proxyClient, modulePath, version, http.StatusNotFound)
 }
 
 func TestFetchAndUpdateState_Incomplete(t *testing.T) {
@@ -253,59 +153,22 @@
 
 	proxyClient, teardownProxy := proxy.SetupTestProxy(t, []*proxy.Module{buildConstraintsMod})
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
 
-	var (
-		modulePath = buildConstraintsMod.ModulePath
-		version    = buildConstraintsMod.Version
-		want       = hasIncompletePackagesCode
-	)
-
-	code, err := FetchAndUpdateState(ctx, modulePath, version, proxyClient, sourceClient, testDB, testAppVersion)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if code != want {
-		t.Fatalf("got code %d, want %d", code, want)
-	}
-	vs, err := testDB.GetModuleVersionState(ctx, modulePath, version)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if vs.Status != want {
-		t.Fatalf("testDB.GetModuleVersionState(ctx, %q, %q): status=%v, want %d", modulePath, version, vs.Status, want)
-	}
-	vm, err := testDB.GetVersionMap(ctx, modulePath, version)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if vm.Status != want {
-		t.Fatalf("testDB.GetVersionMap(ctx,  %q, %q): status=%v, want %d", modulePath, version, vm.Status, want)
-	}
-	gotStates, err := testDB.GetPackageVersionStatesForModule(ctx, modulePath, version)
-	if err != nil {
-		t.Fatal(err)
-	}
-	wantStates := []*internal.PackageVersionState{
+	fetchAndCheckStatus(ctx, t, proxyClient, buildConstraintsMod.ModulePath, buildConstraintsMod.Version, hasIncompletePackagesCode)
+	checkPackageVersionStates(ctx, t, buildConstraintsMod.ModulePath, buildConstraintsMod.Version, []*internal.PackageVersionState{
 		{
-			PackagePath: modulePath + "/cpu",
-			ModulePath:  modulePath,
-			Version:     version,
+			PackagePath: buildConstraintsMod.ModulePath + "/cpu",
+			ModulePath:  buildConstraintsMod.ModulePath,
+			Version:     buildConstraintsMod.Version,
 			Status:      200,
 		},
 		{
-			PackagePath: modulePath + "/ignore",
-			ModulePath:  modulePath,
-			Version:     version,
+			PackagePath: buildConstraintsMod.ModulePath + "/ignore",
+			ModulePath:  buildConstraintsMod.ModulePath,
+			Version:     buildConstraintsMod.Version,
 			Status:      600,
 		},
-	}
-	sort.Slice(gotStates, func(i, j int) bool {
-		return gotStates[i].PackagePath < gotStates[j].PackagePath
 	})
-	if diff := cmp.Diff(wantStates, gotStates); diff != "" {
-		t.Errorf("testDB.GetPackageVersionStatesForModule(ctx, %q, %q) mismatch (-want +got):\n%s", modulePath, version, diff)
-	}
 }
 
 func TestFetchAndUpdateState_Mismatch(t *testing.T) {
@@ -327,35 +190,9 @@
 		},
 	})
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
 
-	code, err := FetchAndUpdateState(ctx, sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, testAppVersion)
-	wantErr := derrors.AlternativeModule
-	wantCode := derrors.ToStatus(wantErr)
-	if code != wantCode || !errors.Is(err, wantErr) {
-		t.Fatalf("got %d, %v; want %d, Is(err, derrors.AlternativeModule)", code, err, wantCode)
-	}
-	_, err = testDB.GetModuleInfo(ctx, sample.ModulePath, sample.VersionString)
-	if !errors.Is(err, derrors.NotFound) {
-		t.Fatalf("got %v, want Is(NotFound)", err)
-	}
-	vs, err := testDB.GetModuleVersionState(ctx, sample.ModulePath, sample.VersionString)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if vs.Status != wantCode {
-		t.Errorf("testDB.GetModuleVersionState(ctx, %q, %q): status=%v, want %d", sample.ModulePath, sample.VersionString, vs.Status, wantCode)
-	}
-	if vs.GoModPath != goModPath {
-		t.Errorf("testDB.GetModuleVersionState(ctx, %q, %q): goModPath=%q, want %q", sample.ModulePath, sample.VersionString, vs.GoModPath, goModPath)
-	}
-	vm, err := testDB.GetVersionMap(ctx, sample.ModulePath, sample.VersionString)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if vm.Status != wantCode {
-		t.Fatalf("testDB.GetVersionMap(ctx, %q, %q): status=%d, want %d", sample.ModulePath, sample.VersionString, vm.Status, wantCode)
-	}
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, sample.VersionString,
+		derrors.ToStatus(derrors.AlternativeModule))
 }
 
 func TestFetchAndUpdateState_DeleteOlder(t *testing.T) {
@@ -391,21 +228,14 @@
 		},
 	})
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
 
-	if _, err := FetchAndUpdateState(ctx, sample.ModulePath, olderVersion, proxyClient, sourceClient, testDB, testAppVersion); err != nil {
-		t.Fatal(err)
-	}
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, olderVersion, http.StatusOK)
 	gotModule, gotVersion, gotFound := postgres.GetFromSearchDocuments(ctx, t, testDB, sample.ModulePath+"/foo")
 	if !gotFound || gotModule != sample.ModulePath || gotVersion != olderVersion {
 		t.Fatalf("got (%q, %q, %t), want (%q, %q, true)", gotModule, gotVersion, gotFound, sample.ModulePath, olderVersion)
 	}
 
-	code, _ := FetchAndUpdateState(ctx, sample.ModulePath, mismatchVersion, proxyClient, sourceClient, testDB, testAppVersion)
-	if want := derrors.ToStatus(derrors.AlternativeModule); code != want {
-		t.Fatalf("got %d, want %d", code, want)
-	}
-
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, mismatchVersion, derrors.ToStatus(derrors.AlternativeModule))
 	if _, _, gotFound := postgres.GetFromSearchDocuments(ctx, t, testDB, sample.ModulePath+"/foo"); gotFound {
 		t.Fatal("older version found in search documents")
 	}
@@ -439,16 +269,7 @@
 		},
 	})
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
-
-	code, err := FetchAndUpdateState(ctx, modulePath, version, proxyClient, sourceClient, testDB, testAppVersion)
-	if err != nil {
-		t.Fatalf("FetchAndUpdateState(%q, %q, %v, %v, %v): %v", modulePath, version, proxyClient, sourceClient, testDB, err)
-	}
-	if code != hasIncompletePackagesCode {
-		t.Errorf("FetchAndUpdateState(%q, %q, %v, %v, %v): hasIncompletePackages=false, want true",
-			modulePath, version, proxyClient, sourceClient, testDB)
-	}
+	fetchAndCheckStatus(ctx, t, proxyClient, modulePath, version, hasIncompletePackagesCode)
 
 	pkgFoo := modulePath + "/foo"
 	if _, err := testDB.LegacyGetPackage(ctx, pkgFoo, internal.UnknownModulePath, version); err != nil {
@@ -503,16 +324,8 @@
 		},
 	})
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
 
-	code, err := FetchAndUpdateState(ctx, sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, testAppVersion)
-	if err != nil {
-		t.Fatalf("FetchAndUpdateState(%q, %q, %v, %v, %v): %v", sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, err)
-	}
-	if code == hasIncompletePackagesCode {
-		t.Errorf("FetchAndUpdateState(%q, %q, %v, %v, %v): hasIncompletePackages=true, want false",
-			sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB)
-	}
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, sample.VersionString, http.StatusOK)
 
 	pkgFoo := sample.ModulePath + "/foo"
 	if _, err := testDB.LegacyGetPackage(ctx, pkgFoo, internal.UnknownModulePath, sample.VersionString); err != nil {
@@ -543,10 +356,7 @@
 		},
 	})
 	defer tearDown()
-	sourceClient := source.NewClient(sourceTimeout)
-	if _, err := FetchAndUpdateState(ctx, sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, testAppVersion); err != nil {
-		t.Fatalf("FetchAndUpdateState: %v", err)
-	}
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, sample.VersionString, http.StatusOK)
 	pkg, err := testDB.LegacyGetPackage(ctx, sample.ModulePath, internal.UnknownModulePath, sample.VersionString)
 	if err != nil {
 		t.Fatal(err)
@@ -667,7 +477,7 @@
 
 var testProxyCommitTime = time.Date(2019, 1, 30, 0, 0, 0, 0, time.UTC)
 
-func TestFetchAndInsertModule(t *testing.T) {
+func TestFetchAndUpdateState(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
 
@@ -850,7 +660,7 @@
 				LegacyModuleInfo: internal.LegacyModuleInfo{
 					ModuleInfo: internal.ModuleInfo{
 						ModulePath:        "nonredistributable.mod/module",
-						Version:           "v1.0.0",
+						Version:           sample.VersionString,
 						CommitTime:        testProxyCommitTime,
 						VersionType:       "release",
 						SourceInfo:        nil,
@@ -1077,19 +887,87 @@
 	}
 }
 
-func TestFetchAndInsertModuleTimeout(t *testing.T) {
+func TestFetchAndUpdateState_Timeout(t *testing.T) {
 	defer postgres.ResetTestDB(testDB, t)
 
 	proxyClient, teardownProxy := proxy.SetupTestProxy(t, nil)
 	defer teardownProxy()
-	sourceClient := source.NewClient(sourceTimeout)
 
-	wantErrString := "deadline exceeded"
 	ctx, cancel := context.WithTimeout(context.Background(), 0)
 	defer cancel()
-	_, err := FetchAndUpdateState(ctx, sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, "appVersionLabel")
-	if err == nil || !strings.Contains(err.Error(), wantErrString) {
-		t.Fatalf("FetchAndUpdateState(%q, %q, %v, %v, %v) returned error %v, want error containing %q",
-			sample.ModulePath, sample.VersionString, proxyClient, sourceClient, testDB, err, wantErrString)
+	fetchAndCheckStatus(ctx, t, proxyClient, sample.ModulePath, sample.VersionString, http.StatusInternalServerError)
+}
+
+func fetchAndCheckStatus(ctx context.Context, t *testing.T, proxyClient *proxy.Client, modulePath, version string, wantCode int) {
+	t.Helper()
+	sourceClient := source.NewClient(sourceTimeout)
+	code, err := FetchAndUpdateState(ctx, modulePath, version, proxyClient, sourceClient, testDB, testAppVersion)
+	switch code {
+	case http.StatusOK:
+		if err != nil {
+			t.Fatalf("FetchAndUpdateState: %v", err)
+		}
+	case derrors.ToStatus(derrors.AlternativeModule):
+		if !errors.Is(err, derrors.AlternativeModule) {
+			t.Fatalf("FetchAndUpdateState: %v; want = %v", err, derrors.AlternativeModule)
+		}
+	case http.StatusNotFound:
+		if !errors.Is(err, derrors.NotFound) {
+			t.Fatalf("FetchAndUpdateState: %v; want = %v", err, derrors.NotFound)
+		}
+	case http.StatusForbidden:
+		if !errors.Is(err, derrors.Excluded) {
+			t.Fatalf("FetchAndUpdateState: %v; want = %v", err, derrors.NotFound)
+		}
+	case http.StatusInternalServerError:
+		// The only case where we check for a status 500 is in
+		// TestFetchAndUpdateState_Timeout.
+		wantErrString := "deadline exceeded"
+		if !strings.Contains(err.Error(), wantErrString) {
+			t.Fatalf("FetchAndUpdateState: %v; want error containing %q", err, wantErrString)
+		}
+		return
+	}
+	if code != wantCode {
+		t.Fatalf("got %d; want = %d", code, wantCode)
+	}
+
+	_, err = testDB.GetModuleInfo(ctx, modulePath, version)
+	switch code {
+	case http.StatusOK, hasIncompletePackagesCode:
+		if err != nil {
+			t.Fatalf("testDB.GetModuleInfo: %v", err)
+		}
+	default:
+		if !errors.Is(err, derrors.NotFound) {
+			t.Fatalf("got %v, want Is(NotFound)", err)
+		}
+	}
+	if semver.IsValid(version) {
+		if _, err := testDB.GetModuleVersionState(ctx, modulePath, version); err != nil {
+			t.Fatal(err)
+		}
+	}
+	vm, err := testDB.GetVersionMap(ctx, modulePath, version)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if vm.Status != wantCode {
+		t.Fatalf("testDB.GetVersionMap(ctx, %q, %q): status = %d, want = %d", modulePath, version, vm.Status, wantCode)
+	}
+}
+
+func checkPackageVersionStates(ctx context.Context, t *testing.T, modulePath, version string, wantStates []*internal.PackageVersionState) {
+	t.Helper()
+	gotStates, err := testDB.GetPackageVersionStatesForModule(ctx, modulePath, version)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sort.Slice(gotStates, func(i, j int) bool {
+		return gotStates[i].PackagePath < gotStates[j].PackagePath
+	})
+	if diff := cmp.Diff(wantStates, gotStates); diff != "" {
+		t.Errorf("testDB.GetPackageVersionStatesForModule(ctx, %q, %q) mismatch (-want +got):\n%s",
+			modulePath, version, diff)
 	}
 }