internal/fetch: move load-shedding into FetchModule

The load-shedding logic can live entirely inside FetchModule,
as long as the caller defers the function returned in FetchResult.Defer.

GetModuleInfo is no longer needed.

Change-Id: I1eacf96ebf06cfab57e184d8b4382a58c5b174ca
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/255978
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go
index beea630..8675847 100644
--- a/internal/fetch/fetch.go
+++ b/internal/fetch/fetch.go
@@ -17,12 +17,14 @@
 	"go/token"
 	"io"
 	"io/ioutil"
+	"math"
 	"net/http"
 	"os"
 	"path"
 	"runtime"
 	"runtime/debug"
 	"sort"
+	"strconv"
 	"strings"
 	"time"
 
@@ -46,53 +48,6 @@
 	errMalformedZip             = errors.New("module zip is malformed")
 )
 
-// ModuleInfo holds some basic, easy-to-get information about a module.
-type ModuleInfo struct {
-	ModulePath       string
-	RequestedVersion string
-	ResolvedVersion  string
-	CommitTime       time.Time
-	ZipSize          int64
-	Error            error
-}
-
-// GetModuleInfo returns preliminary information about a module, from the
-// proxy's .info and .zip endpoints (the latter via HEAD only, to avoid
-// downloading the zip).
-func GetModuleInfo(ctx context.Context, modulePath, requestedVersion string, proxyClient *proxy.Client) ModuleInfo {
-	mi := ModuleInfo{
-		ModulePath:       modulePath,
-		RequestedVersion: requestedVersion,
-	}
-	defer derrors.Wrap(&mi.Error, "GetModuleInfo(%q, %q)", modulePath, requestedVersion)
-
-	if modulePath == stdlib.ModulePath {
-		resolvedVersion, zipSize, err := stdlib.ZipInfo(requestedVersion)
-		if err != nil {
-			mi.Error = err
-			return mi
-		}
-		mi.ResolvedVersion = resolvedVersion
-		mi.ZipSize = zipSize
-		// CommitTime unknown
-		return mi
-	}
-	info, err := proxyClient.GetInfo(ctx, modulePath, requestedVersion)
-	if err != nil {
-		mi.Error = err
-		return mi
-	}
-	zipSize, err := proxyClient.GetZipSize(ctx, modulePath, info.Version)
-	if err != nil {
-		mi.Error = err
-		return mi
-	}
-	mi.ResolvedVersion = info.Version
-	mi.CommitTime = info.Time
-	mi.ZipSize = zipSize
-	return mi
-}
-
 type FetchResult struct {
 	ModulePath           string
 	RequestedVersion     string
@@ -100,52 +55,92 @@
 	GoModPath            string
 	Status               int
 	Error                error
+	Defer                func() // caller must defer this on all code paths
 	Module               *internal.Module
 	PackageVersionStates []*internal.PackageVersionState
 }
 
-// FetchModule queries the proxy or the Go repo using the ModuleInfo obtained
-// from GetModuleInfo. It then downloads the module zip, and processes the
-// contents to return an *internal.Module and related information.
+// FetchModule queries the proxy or the Go repo for the requested module
+// version, downloads the module zip, and processes the contents to return an
+// *internal.Module and related information.
 //
 // Even if err is non-nil, the result may contain useful information, like the go.mod path.
-func FetchModule(ctx context.Context, mi ModuleInfo, proxyClient *proxy.Client, sourceClient *source.Client) (fr *FetchResult) {
+//
+// Callers of FetchModule must
+//   defer fr.Defer()
+// immediately after the call.
+func FetchModule(ctx context.Context, modulePath, requestedVersion string, proxyClient *proxy.Client, sourceClient *source.Client) (fr *FetchResult) {
 	fr = &FetchResult{
-		ModulePath:       mi.ModulePath,
-		RequestedVersion: mi.RequestedVersion,
+		ModulePath:       modulePath,
+		RequestedVersion: requestedVersion,
+		Defer:            func() {},
 	}
 	defer func() {
 		if fr.Error != nil {
-			derrors.Wrap(&fr.Error, "FetchModule(%q, %q)", mi.ModulePath, mi.RequestedVersion)
+			derrors.Wrap(&fr.Error, "FetchModule(%q, %q)", modulePath, requestedVersion)
 			fr.Status = derrors.ToStatus(fr.Error)
 		}
 		if fr.Status == 0 {
 			fr.Status = http.StatusOK
 		}
-		log.Debugf(ctx, "memory after fetch of %s@%s: %dM", mi.ModulePath, mi.RequestedVersion, allocMeg())
+		log.Debugf(ctx, "memory after fetch of %s@%s: %dM", modulePath, requestedVersion, allocMeg())
 	}()
 
-	if mi.Error != nil {
-		fr.Error = mi.Error
-		return fr
-	}
-
 	var (
 		commitTime time.Time
 		zipReader  *zip.Reader
+		zipSize    int64
 		err        error
 	)
-	fr.ResolvedVersion = mi.ResolvedVersion
-	if mi.ModulePath == stdlib.ModulePath {
-		zipReader, commitTime, err = stdlib.Zip(mi.ResolvedVersion)
+	// Get just the information we need to make a load-shedding decision.
+	if modulePath == stdlib.ModulePath {
+		var resolvedVersion string
+		resolvedVersion, zipSize, err = stdlib.ZipInfo(requestedVersion)
+		if err != nil {
+			fr.Error = err
+			return fr
+		}
+		fr.ResolvedVersion = resolvedVersion
+	} else {
+		info, err := proxyClient.GetInfo(ctx, modulePath, requestedVersion)
+		if err != nil {
+			fr.Error = err
+			return fr
+		}
+		fr.ResolvedVersion = info.Version
+		commitTime = info.Time
+		zipSize, err = proxyClient.GetZipSize(ctx, modulePath, fr.ResolvedVersion)
+		if err != nil {
+			fr.Error = err
+			return fr
+		}
+	}
+
+	// Load shed or mark module as too large.
+	shouldShed, deferFunc := decideToShed(uint64(zipSize))
+	fr.Defer = deferFunc
+	if shouldShed {
+		fr.Error = derrors.SheddingLoad
+		return fr
+	}
+
+	if zipSize > maxModuleZipSize {
+		log.Warningf(ctx, "FetchModule: %s@%s zip size %dMi exceeds max %dMi",
+			modulePath, fr.ResolvedVersion, zipSize/mib, maxModuleZipSize/mib)
+		fr.Error = derrors.ModuleTooLarge
+		return fr
+	}
+
+	// Proceed with the fetch.
+	if modulePath == stdlib.ModulePath {
+		zipReader, commitTime, err = stdlib.Zip(requestedVersion)
 		if err != nil {
 			fr.Error = err
 			return fr
 		}
 		fr.GoModPath = stdlib.ModulePath
 	} else {
-		commitTime = mi.CommitTime
-		goModBytes, err := proxyClient.GetMod(ctx, mi.ModulePath, fr.ResolvedVersion)
+		goModBytes, err := proxyClient.GetMod(ctx, modulePath, fr.ResolvedVersion)
 		if err != nil {
 			fr.Error = err
 			return fr
@@ -156,27 +151,27 @@
 			return fr
 		}
 		fr.GoModPath = goModPath
-		if goModPath != mi.ModulePath {
+		if goModPath != modulePath {
 			// The module path in the go.mod file doesn't match the path of the
 			// zip file. Don't insert the module. Store an AlternativeModule
 			// status in module_version_states.
-			fr.Error = fmt.Errorf("module path=%s, go.mod path=%s: %w", mi.ModulePath, goModPath, derrors.AlternativeModule)
+			fr.Error = fmt.Errorf("module path=%s, go.mod path=%s: %w", modulePath, goModPath, derrors.AlternativeModule)
 			return fr
 		}
-		zipReader, err = proxyClient.GetZip(ctx, mi.ModulePath, fr.ResolvedVersion)
+		zipReader, err = proxyClient.GetZip(ctx, modulePath, fr.ResolvedVersion)
 		if err != nil {
 			fr.Error = err
 			return fr
 		}
 	}
-	mod, pvs, err := processZipFile(ctx, mi.ModulePath, fr.ResolvedVersion, commitTime, zipReader, sourceClient)
+	mod, pvs, err := processZipFile(ctx, modulePath, fr.ResolvedVersion, commitTime, zipReader, sourceClient)
 	if err != nil {
 		fr.Error = err
 		return fr
 	}
 	fr.Module = mod
 	fr.PackageVersionStates = pvs
-	if mi.ModulePath == stdlib.ModulePath {
+	if modulePath == stdlib.ModulePath {
 		fr.Module.HasGoMod = true
 	}
 	for _, state := range fr.PackageVersionStates {
@@ -792,3 +787,22 @@
 	runtime.ReadMemStats(&ms)
 	return int(ms.Alloc / (1024 * 1024))
 }
+
+// mib is the number of bytes in a mebibyte (Mi).
+const mib = 1024 * 1024
+
+// The largest module zip size we can comfortably process.
+// We probably will OOM if we process a module whose zip is larger.
+var maxModuleZipSize int64 = math.MaxInt64
+
+// init reads GO_DISCOVERY_MAX_MODULE_ZIP_MI (a size in Mi) to override maxModuleZipSize.
+func init() {
+	if m := os.Getenv("GO_DISCOVERY_MAX_MODULE_ZIP_MI"); m != "" {
+		v, err := strconv.ParseInt(m, 10, 64)
+		if err != nil { // report the raw string m; v is meaningless when parsing fails
+			log.Errorf(context.Background(), "could not parse GO_DISCOVERY_MAX_MODULE_ZIP_MI value %q", m)
+		} else {
+			maxModuleZipSize = v * mib
+		}
+	}
+}
diff --git a/internal/fetch/fetch_test.go b/internal/fetch/fetch_test.go
index 2135ec3..66efa57 100644
--- a/internal/fetch/fetch_test.go
+++ b/internal/fetch/fetch_test.go
@@ -91,7 +91,8 @@
 				Files:      test.mod.mod.Files,
 			}})
 			defer teardownProxy()
-			got := FetchModule(ctx, GetModuleInfo(ctx, modulePath, fetchVersion, proxyClient), proxyClient, sourceClient)
+			got := FetchModule(ctx, modulePath, fetchVersion, proxyClient, sourceClient)
+			defer got.Defer()
 			if got.Error != nil {
 				t.Fatal(got.Error)
 			}
@@ -103,6 +104,7 @@
 				cmpopts.IgnoreFields(internal.LegacyPackage{}, "DocumentationHTML"),
 				cmpopts.IgnoreFields(internal.Documentation{}, "HTML"),
 				cmpopts.IgnoreFields(internal.PackageVersionState{}, "Error"),
+				cmpopts.IgnoreFields(FetchResult{}, "Defer"),
 				cmp.AllowUnexported(source.Info{}),
 				cmpopts.EquateEmpty(),
 			}
@@ -139,7 +141,8 @@
 			defer teardownProxy()
 
 			sourceClient := source.NewClient(sourceTimeout)
-			got := FetchModule(ctx, GetModuleInfo(ctx, modulePath, "v1.0.0", proxyClient), proxyClient, sourceClient)
+			got := FetchModule(ctx, modulePath, "v1.0.0", proxyClient, sourceClient)
+			defer got.Defer()
 			if !errors.Is(got.Error, test.wantErr) {
 				t.Fatalf("FetchModule(ctx, %q, v1.0.0, proxyClient, sourceClient): %v; wantErr = %v)", modulePath, got.Error, test.wantErr)
 			}
diff --git a/internal/worker/loadshedding.go b/internal/fetch/loadshedding.go
similarity index 91%
rename from internal/worker/loadshedding.go
rename to internal/fetch/loadshedding.go
index c56957c..9140b4c 100644
--- a/internal/worker/loadshedding.go
+++ b/internal/fetch/loadshedding.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package worker
+package fetch
 
 import (
 	"context"
@@ -72,7 +72,8 @@
 	}
 }
 
-type loadShedStats struct {
+// LoadShedStats holds statistics about load shedding.
+type LoadShedStats struct {
 	FetchesInFlight     int
 	ZipBytesInFlight    uint64
 	MaxZipBytesInFlight uint64
@@ -80,10 +81,11 @@
 	TotalRequests       int
 }
 
-func getLoadShedStats() loadShedStats {
+// GetLoadShedStats returns a snapshot of the current LoadShedStats.
+func GetLoadShedStats() LoadShedStats {
 	shedmu.Lock()
 	defer shedmu.Unlock()
-	return loadShedStats{
+	return LoadShedStats{
 		FetchesInFlight:     fetchesInFlight,
 		ZipBytesInFlight:    zipSizeInFlight,
 		MaxZipBytesInFlight: maxZipSizeInFlight,
diff --git a/internal/fetch/loadshedding_test.go b/internal/fetch/loadshedding_test.go
new file mode 100644
index 0000000..724e3c3
--- /dev/null
+++ b/internal/fetch/loadshedding_test.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fetch
+
+import "testing"
+
+// TestDecideToShed checks decideToShed's decisions and its in-flight byte accounting.
+func TestDecideToShed(t *testing.T) {
+	// By default (GO_DISCOVERY_MAX_IN_FLIGHT_ZIP_MI is unset), we should never decide to shed no matter the size of the zip.
+	got, d := decideToShed(1e10)
+	if want := false; got != want {
+		t.Fatalf("got %t, want %t", got, want)
+	}
+	d() // reset zipSizeInFlight
+	// Restore the package-level limit on exit so other tests in this package are not affected.
+	defer func(old uint64) { maxZipSizeInFlight = old }(maxZipSizeInFlight)
+	maxZipSizeInFlight = 10 * mib
+	got, d = decideToShed(3 * mib)
+	if want := false; got != want {
+		t.Fatalf("got %t, want %t", got, want)
+	}
+	bytesInFlight := func() int { return int(GetLoadShedStats().ZipBytesInFlight) }
+	if got, want := bytesInFlight(), 3*mib; got != want {
+		t.Fatalf("got %d, want %d", got, want)
+	}
+	got, _ = decideToShed(8 * mib) // 8 + 3 > 10; shed
+	if want := true; got != want {
+		t.Fatalf("got %t, want %t", got, want)
+	}
+	d() // should decrement zipSizeInFlight
+	if got, want := bytesInFlight(), 0; got != want {
+		t.Fatalf("got %d, want %d", got, want)
+	}
+	got, d = decideToShed(8 * mib) // 8 < 10; do not shed
+	if want := false; got != want {
+		t.Fatalf("got %t, want %t", got, want)
+	}
+	d()
+	if got, want := bytesInFlight(), 0; got != want {
+		t.Fatalf("got %d, want %d", got, want)
+	}
+}
diff --git a/internal/frontend/fetch.go b/internal/frontend/fetch.go
index 3916c78..8ad0958 100644
--- a/internal/frontend/fetch.go
+++ b/internal/frontend/fetch.go
@@ -494,8 +494,8 @@
 		derrors.Wrap(&err, "FetchAndUpdateState(%q, %q)", modulePath, requestedVersion)
 	}()
 
-	mi := fetch.GetModuleInfo(ctx, modulePath, requestedVersion, proxyClient)
-	fr := fetch.FetchModule(ctx, mi, proxyClient, sourceClient)
+	fr := fetch.FetchModule(ctx, modulePath, requestedVersion, proxyClient, sourceClient)
+	defer fr.Defer()
 	if fr.Error == nil {
 		// Only attempt to insert the module into module_version_states if the
 		// fetch process was successful.
diff --git a/internal/proxydatasource/datasource.go b/internal/proxydatasource/datasource.go
index 5b2faa7..ead449e 100644
--- a/internal/proxydatasource/datasource.go
+++ b/internal/proxydatasource/datasource.go
@@ -88,7 +88,8 @@
 	if e, ok := ds.versionCache[key]; ok {
 		return e.module, e.err
 	}
-	res := fetch.FetchModule(ctx, fetch.GetModuleInfo(ctx, modulePath, version, ds.proxyClient), ds.proxyClient, ds.sourceClient)
+	res := fetch.FetchModule(ctx, modulePath, version, ds.proxyClient, ds.sourceClient)
+	defer res.Defer()
 	m := res.Module
 	if m != nil {
 		if ds.bypassLicenseCheck {
diff --git a/internal/testing/integration/frontend_test.go b/internal/testing/integration/frontend_test.go
index fe9d9d0..d364e9e 100644
--- a/internal/testing/integration/frontend_test.go
+++ b/internal/testing/integration/frontend_test.go
@@ -222,8 +222,8 @@
 
 func fetchAndInsertModule(ctx context.Context, t *testing.T, tm *proxy.Module, proxyClient *proxy.Client) {
 	sourceClient := source.NewClient(1 * time.Second)
-	mi := fetch.GetModuleInfo(ctx, tm.ModulePath, tm.Version, proxyClient)
-	res := fetch.FetchModule(ctx, mi, proxyClient, sourceClient)
+	res := fetch.FetchModule(ctx, tm.ModulePath, tm.Version, proxyClient, sourceClient)
+	defer res.Defer()
 	if res.Error != nil {
 		t.Fatal(res.Error)
 	}
diff --git a/internal/worker/fetch.go b/internal/worker/fetch.go
index 95b5489..3cf29ea 100644
--- a/internal/worker/fetch.go
+++ b/internal/worker/fetch.go
@@ -7,11 +7,8 @@
 import (
 	"context"
 	"fmt"
-	"math"
 	"net/http"
-	"os"
 	"sort"
-	"strconv"
 	"strings"
 	"time"
 
@@ -115,25 +112,6 @@
 	return ft.Status, ft.Error
 }
 
-// mib is the number of bytes in a mebibyte (Mi).
-const mib = 1024 * 1024
-
-// The largest module zip size we can comfortably process.
-// We probably will OOM if we process a module whose zip is larger.
-var maxModuleZipSize int64 = math.MaxInt64
-
-func init() {
-	m := os.Getenv("GO_DISCOVERY_MAX_MODULE_ZIP_MI")
-	if m != "" {
-		v, err := strconv.ParseInt(m, 10, 64)
-		if err != nil {
-			log.Errorf(context.Background(), "could not parse GO_DISCOVERY_MAX_MODULE_ZIP_MI value %q", v)
-		} else {
-			maxModuleZipSize = v * mib
-		}
-	}
-}
-
 // fetchAndInsertModule fetches the given module version from the module proxy
 // or (in the case of the standard library) from the Go repo and writes the
 // resulting data to the database.
@@ -174,26 +152,11 @@
 	}
 
 	start := time.Now()
-	minfo := fetch.GetModuleInfo(ctx, modulePath, requestedVersion, proxyClient)
-	if minfo.Error == nil {
-		shouldShed, deferFunc := decideToShed(uint64(minfo.ZipSize))
-		defer deferFunc()
-		if shouldShed {
-			ft.Error = derrors.SheddingLoad
-			return ft
-		}
-	}
-	if minfo.Error == nil && minfo.ZipSize > maxModuleZipSize {
-		log.Warningf(ctx, "fetchAndInsertModule: %s@%s zip size %dMi exceeds max %dMi",
-			minfo.ModulePath, minfo.ResolvedVersion, minfo.ZipSize/mib, maxModuleZipSize/mib)
-		ft.Error = derrors.ModuleTooLarge
-		return ft
-	}
-
-	fr := fetch.FetchModule(ctx, minfo, proxyClient, sourceClient)
+	fr := fetch.FetchModule(ctx, modulePath, requestedVersion, proxyClient, sourceClient)
 	if fr == nil {
 		panic("fetch.FetchModule should never return a nil FetchResult")
 	}
+	defer fr.Defer()
 	ft.FetchResult = *fr
 	ft.timings["fetch.FetchModule"] = time.Since(start)
 	if ft.Error != nil {
diff --git a/internal/worker/fetch_test.go b/internal/worker/fetch_test.go
index 9393109..0aa7b37 100644
--- a/internal/worker/fetch_test.go
+++ b/internal/worker/fetch_test.go
@@ -993,40 +993,3 @@
 			um.Path, um.ModulePath, um.Version)
 	}
 }
-
-func TestDecideToShed(t *testing.T) {
-	// By default (GO_DISCOVERY_MAX_IN_FLIGHT_ZIP_MI is unset), we should never decide to shed no matter the size of the zip.
-	got, d := decideToShed(1e10)
-	if want := false; got != want {
-		t.Fatalf("got %t, want %t", got, want)
-	}
-	d() // reset zipSizeInFlight
-	maxZipSizeInFlight = 10 * mib
-	got, d = decideToShed(3 * mib)
-	if want := false; got != want {
-		t.Fatalf("got %t, want %t", got, want)
-	}
-	bytesInFlight := func() int {
-		return int(getLoadShedStats().ZipBytesInFlight)
-	}
-
-	if got, want := bytesInFlight(), 3*mib; got != want {
-		t.Fatalf("got %d, want %d", got, want)
-	}
-	got, _ = decideToShed(8 * mib) // 8 + 3 > 10; shed
-	if want := true; got != want {
-		t.Fatalf("got %t, want %t", got, want)
-	}
-	d() // should decrement zipSizeInFlight
-	if got, want := bytesInFlight(), 0; got != want {
-		t.Fatalf("got %d, want %d", got, want)
-	}
-	got, d = decideToShed(8 * mib) // 8 < 10; do not shed
-	if want := false; got != want {
-		t.Fatalf("got %t, want %t", got, want)
-	}
-	d()
-	if got, want := bytesInFlight(), 0; got != want {
-		t.Fatalf("got %d, want %d", got, want)
-	}
-}
diff --git a/internal/worker/pages.go b/internal/worker/pages.go
index 73f7619..76d97c8 100644
--- a/internal/worker/pages.go
+++ b/internal/worker/pages.go
@@ -23,6 +23,7 @@
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/config"
 	"golang.org/x/pkgsite/internal/derrors"
+	"golang.org/x/pkgsite/internal/fetch"
 	"golang.org/x/pkgsite/internal/log"
 	"golang.org/x/pkgsite/internal/postgres"
 	"golang.org/x/sync/errgroup"
@@ -80,7 +81,7 @@
 		LocationID      string
 		Experiments     []*internal.Experiment
 		Excluded        []string
-		LoadShedStats   loadShedStats
+		LoadShedStats   fetch.LoadShedStats
 		GoMemStats      runtime.MemStats
 		ProcessStats    processMemStats
 		SystemStats     systemMemStats
@@ -91,7 +92,7 @@
 		LocationID:     s.cfg.LocationID,
 		Experiments:    experiments,
 		Excluded:       excluded,
-		LoadShedStats:  getLoadShedStats(),
+		LoadShedStats:  fetch.GetLoadShedStats(),
 		GoMemStats:     gms,
 		ProcessStats:   pms,
 		SystemStats:    sms,
diff --git a/internal/worker/server.go b/internal/worker/server.go
index 5959a48..54a2ba6 100644
--- a/internal/worker/server.go
+++ b/internal/worker/server.go
@@ -543,7 +543,7 @@
 
 // bytesToMi converts an integral value of bytes into mebibytes.
 func bytesToMi(b uint64) uint64 {
-	return b / mib
+	return b / (1024 * 1024)
 }
 
 // percentage computes the truncated percentage of x/y.