internal/{worker,fetch}: include DB activity in load-shedding

Move the load-shedding logic to the worker so that it covers not only
fetching and processing the module (as before) but also inserting the
module into the database.

This is a more accurate estimate of load, since running a lot of
concurrent database queries definitely slows down processing.

Most of the time this won't make much difference, but under high load,
such as when processing multiple large modules, it will reduce DB
contention and should result in greater throughput.

For golang/go#48010

Change-Id: I7d0922e02d00182e867fd3b29fc284c32ecab5ee
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/346749
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/cmd/worker/main.go b/cmd/worker/main.go
index e26b127..feeb1d5 100644
--- a/cmd/worker/main.go
+++ b/cmd/worker/main.go
@@ -116,9 +116,9 @@
 		worker.ProcessingLag,
 		worker.UnprocessedModules,
 		worker.UnprocessedNewModules,
+		worker.SheddedFetchCount,
 		fetch.FetchLatencyDistribution,
 		fetch.FetchResponseCount,
-		fetch.SheddedFetchCount,
 		fetch.FetchPackageCount)
 	if err := dcensus.Init(cfg, views...); err != nil {
 		log.Fatal(ctx, err)
diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go
index 3833379..b7adfc0 100644
--- a/internal/fetch/fetch.go
+++ b/internal/fetch/fetch.go
@@ -41,11 +41,6 @@
 		"Latency of a fetch request.",
 		stats.UnitSeconds,
 	)
-	fetchesShedded = stats.Int64(
-		"go-discovery/worker/fetch-shedded",
-		"Count of shedded fetches.",
-		stats.UnitDimensionless,
-	)
 	fetchedPackages = stats.Int64(
 		"go-discovery/worker/fetch-package-count",
 		"Count of successfully fetched packages.",
@@ -76,13 +71,6 @@
 		Aggregation: view.Count(),
 		Description: "Count of packages successfully fetched",
 	}
-	// SheddedFetchCount counts the number of fetches that were shedded.
-	SheddedFetchCount = &view.View{
-		Name:        "go-discovery/worker/fetch-shedded",
-		Measure:     fetchesShedded,
-		Aggregation: view.Count(),
-		Description: "Count of shedded fetches",
-	}
 )
 
 type FetchResult struct {
@@ -98,7 +86,6 @@
 	GoModPath            string
 	Status               int
 	Error                error
-	Defer                func() // caller must defer this on all code paths
 	Module               *internal.Module
 	PackageVersionStates []*internal.PackageVersionState
 }
@@ -108,10 +95,6 @@
 // *internal.Module and related information.
 //
 // Even if err is non-nil, the result may contain useful information, like the go.mod path.
-//
-// Callers of FetchModule must
-//   defer fr.Defer()
-// immediately after the call.
 func FetchModule(ctx context.Context, modulePath, requestedVersion string, mg ModuleGetter, sourceClient *source.Client) (fr *FetchResult) {
 	start := time.Now()
 	defer func() {
@@ -125,7 +108,6 @@
 	fr = &FetchResult{
 		ModulePath:       modulePath,
 		RequestedVersion: requestedVersion,
-		Defer:            func() {},
 	}
 	defer derrors.Wrap(&fr.Error, "FetchModule(%q, %q)", modulePath, requestedVersion)
 
@@ -151,35 +133,11 @@
 	fr.ResolvedVersion = info.Version
 	commitTime := info.Time
 
-	var zipSize int64
-	if zipLoadShedder != nil {
-		var err error
-		zipSize, err = getZipSize(ctx, fr.ModulePath, fr.ResolvedVersion, mg)
-		if err != nil {
-			return nil, err
-		}
-		// Load shed or mark module as too large.
-		// We treat zip size as a proxy for the total memory consumed by
-		// processing a module, and use it to decide whether we can currently
-		// afford to process a module.
-		shouldShed, deferFunc := zipLoadShedder.shouldShed(uint64(zipSize))
-		fr.Defer = deferFunc
-		if shouldShed {
-			stats.Record(ctx, fetchesShedded.M(1))
-			return nil, fmt.Errorf("%w: size=%dMi", derrors.SheddingLoad, zipSize/mib)
-		}
-		if zipSize > maxModuleZipSize {
-			log.Warningf(ctx, "FetchModule: %s@%s zip size %dMi exceeds max %dMi",
-				fr.ModulePath, fr.ResolvedVersion, zipSize/mib, maxModuleZipSize/mib)
-			return nil, derrors.ModuleTooLarge
-		}
-	}
-
-	// Proceed with the fetch.
+	// TODO(golang/go#48010): move fetch info to the worker.
 	fi := &FetchInfo{
 		ModulePath: fr.ModulePath,
 		Version:    fr.ResolvedVersion,
-		ZipSize:    uint64(zipSize),
+		ZipSize:    uint64(0),
 		Start:      time.Now(),
 	}
 	startFetchInfo(fi)
@@ -268,13 +226,6 @@
 	return mg.Info(ctx, modulePath, requestedVersion)
 }
 
-func getZipSize(ctx context.Context, modulePath, resolvedVersion string, mg ModuleGetter) (_ int64, err error) {
-	if modulePath == stdlib.ModulePath {
-		return stdlib.EstimatedZipSize, nil
-	}
-	return mg.ZipSize(ctx, modulePath, resolvedVersion)
-}
-
 // getGoModPath returns the module path from the go.mod file, as well as the
 // contents of the file obtained from the module getter. If modulePath is the
 // standard library, then the contents will be nil.
diff --git a/internal/fetch/fetch_test.go b/internal/fetch/fetch_test.go
index dd260e3..cb6c5a1 100644
--- a/internal/fetch/fetch_test.go
+++ b/internal/fetch/fetch_test.go
@@ -118,7 +118,6 @@
 					defer cancel()
 
 					got, d := fetcher.fetch(t, true, ctx, mod, test.fetchVersion)
-					defer got.Defer()
 					if got.Error != nil {
 						t.Fatalf("fetching failed: %v", got.Error)
 					}
@@ -129,7 +128,6 @@
 					opts := []cmp.Option{
 						cmpopts.IgnoreFields(internal.Documentation{}, "Source"),
 						cmpopts.IgnoreFields(internal.PackageVersionState{}, "Error"),
-						cmpopts.IgnoreFields(FetchResult{}, "Defer"),
 						cmp.AllowUnexported(source.Info{}),
 						cmpopts.EquateEmpty(),
 					}
@@ -216,7 +214,6 @@
 		} {
 			t.Run(fmt.Sprintf("%s:%s", fetcher.name, test.name), func(t *testing.T) {
 				got, _ := fetcher.fetch(t, false, ctx, test.mod.mod, "")
-				defer got.Defer()
 				if !errors.Is(got.Error, test.wantErr) {
 					t.Fatalf("got error = %v; wantErr = %v)", got.Error, test.wantErr)
 				}
diff --git a/internal/fetch/load.go b/internal/fetch/load.go
index 8c47df7..ab3135b 100644
--- a/internal/fetch/load.go
+++ b/internal/fetch/load.go
@@ -17,7 +17,6 @@
 	"io"
 	"io/fs"
 	"io/ioutil"
-	"math"
 	"net/http"
 	"os"
 	"path"
@@ -26,10 +25,8 @@
 
 	"go.opencensus.io/trace"
 	"golang.org/x/pkgsite/internal"
-	"golang.org/x/pkgsite/internal/config"
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/godoc"
-	"golang.org/x/pkgsite/internal/log"
 	"golang.org/x/pkgsite/internal/source"
 	"golang.org/x/pkgsite/internal/stdlib"
 )
@@ -358,36 +355,3 @@
 	defer f.Close()
 	return ioutil.ReadAll(io.LimitReader(f, limit))
 }
-
-// mib is the number of bytes in a mebibyte (Mi).
-const mib = 1024 * 1024
-
-// The largest module zip size we can comfortably process.
-// We probably will OOM if we process a module whose zip is larger.
-var maxModuleZipSize int64 = math.MaxInt64
-
-func init() {
-	v := config.GetEnvInt(context.Background(), "GO_DISCOVERY_MAX_MODULE_ZIP_MI", -1)
-	if v > 0 {
-		maxModuleZipSize = int64(v) * mib
-	}
-}
-
-var zipLoadShedder *loadShedder
-
-func init() {
-	ctx := context.Background()
-	mebis := config.GetEnvInt(ctx, "GO_DISCOVERY_MAX_IN_FLIGHT_ZIP_MI", -1)
-	if mebis > 0 {
-		log.Infof(ctx, "shedding load over %dMi", mebis)
-		zipLoadShedder = &loadShedder{maxSizeInFlight: uint64(mebis) * mib}
-	}
-}
-
-// ZipLoadShedStats returns a snapshot of the current LoadShedStats for zip files.
-func ZipLoadShedStats() LoadShedStats {
-	if zipLoadShedder != nil {
-		return zipLoadShedder.stats()
-	}
-	return LoadShedStats{}
-}
diff --git a/internal/fetchdatasource/fetchdatasource.go b/internal/fetchdatasource/fetchdatasource.go
index 84b44c1..e7e38c4 100644
--- a/internal/fetchdatasource/fetchdatasource.go
+++ b/internal/fetchdatasource/fetchdatasource.go
@@ -128,7 +128,6 @@
 	}()
 	for _, g := range ds.opts.Getters {
 		fr := fetch.FetchModule(ctx, modulePath, version, g, ds.opts.SourceClient)
-		defer fr.Defer()
 		if fr.Error == nil {
 			m := fr.Module
 			if ds.opts.BypassLicenseCheck {
diff --git a/internal/frontend/fetch.go b/internal/frontend/fetch.go
index c468f61..94540e1 100644
--- a/internal/frontend/fetch.go
+++ b/internal/frontend/fetch.go
@@ -557,7 +557,6 @@
 	}()
 
 	fr := fetch.FetchModule(ctx, modulePath, requestedVersion, fetch.NewProxyModuleGetter(proxyClient), sourceClient)
-	defer fr.Defer()
 	if fr.Error == nil {
 		// Only attempt to insert the module into module_version_states if the
 		// fetch process was successful.
diff --git a/internal/testing/integration/frontend_test.go b/internal/testing/integration/frontend_test.go
index c64bbb6..ffeb978 100644
--- a/internal/testing/integration/frontend_test.go
+++ b/internal/testing/integration/frontend_test.go
@@ -100,7 +100,6 @@
 func fetchAndInsertModule(ctx context.Context, t *testing.T, tm *proxytest.Module, proxyClient *proxy.Client) {
 	sourceClient := source.NewClient(1 * time.Second)
 	res := fetch.FetchModule(ctx, tm.ModulePath, tm.Version, fetch.NewProxyModuleGetter(proxyClient), sourceClient)
-	defer res.Defer()
 	if res.Error != nil {
 		t.Fatal(res.Error)
 	}
diff --git a/internal/worker/fetch.go b/internal/worker/fetch.go
index d1cf479..5610ea4 100644
--- a/internal/worker/fetch.go
+++ b/internal/worker/fetch.go
@@ -8,6 +8,7 @@
 	"context"
 	"errors"
 	"fmt"
+	"math"
 	"net/http"
 	"sort"
 	"strings"
@@ -15,10 +16,13 @@
 	"time"
 	"unicode/utf8"
 
+	"go.opencensus.io/stats"
+	"go.opencensus.io/stats/view"
 	"go.opencensus.io/trace"
 	"golang.org/x/mod/semver"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/cache"
+	"golang.org/x/pkgsite/internal/config"
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/experiment"
 	"golang.org/x/pkgsite/internal/fetch"
@@ -29,6 +33,22 @@
 	"golang.org/x/pkgsite/internal/stdlib"
 )
 
+var (
+	fetchesShedded = stats.Int64(
+		"go-discovery/worker/fetch-shedded",
+		"Count of shedded fetches.",
+		stats.UnitDimensionless,
+	)
+
+	// SheddedFetchCount counts the number of fetches that were shedded.
+	SheddedFetchCount = &view.View{
+		Name:        "go-discovery/worker/fetch-shedded",
+		Measure:     fetchesShedded,
+		Aggregation: view.Count(),
+		Description: "Count of shedded fetches",
+	}
+)
+
 // fetchTask represents the result of a fetch task that was processed.
 type fetchTask struct {
 	fetch.FetchResult
@@ -66,6 +86,13 @@
 		trace.StringAttribute("version", requestedVersion))
 	defer span.End()
 
+	// If we're overloaded, shed load by not processing this module.
+	deferFunc, err := f.maybeShed(ctx, modulePath, requestedVersion)
+	defer deferFunc()
+	if err != nil {
+		return derrors.ToStatus(err), "", err
+	}
+
 	// Begin by htting the proxy's info endpoint. That will make the proxy aware
 	// of the version if it isn't already, as can happen when we arrive here via
 	// frontend fetch. We ignore both the error and the information itself at
@@ -201,7 +228,6 @@
 		if fr == nil {
 			panic("fetch.FetchModule should never return a nil FetchResult")
 		}
-		defer fr.Defer()
 		ft.FetchResult = *fr
 		ft.timings["fetch.FetchModule"] = time.Since(start)
 	}()
@@ -418,3 +444,68 @@
 	}
 	return f.DB.UpdateLatestModuleVersions(ctx, lmv)
 }
+
+func (f *Fetcher) maybeShed(ctx context.Context, modulePath, version string) (func(), error) {
+	if zipLoadShedder == nil {
+		return func() {}, nil
+	}
+	zipSize, err := getZipSize(ctx, modulePath, version, f.ProxyClient)
+	if err != nil {
+		return func() {}, err
+	}
+	// Load shed or mark module as too large.
+	// We treat zip size as a proxy for the total memory consumed by
+	// processing a module, and use it to decide whether we can currently
+	// afford to process a module.
+	shouldShed, deferFunc := zipLoadShedder.shouldShed(uint64(zipSize))
+	if shouldShed {
+		stats.Record(ctx, fetchesShedded.M(1))
+		return deferFunc, fmt.Errorf("%w: size=%dMi", derrors.SheddingLoad, zipSize/mib)
+	}
+	if zipSize > maxModuleZipSize {
+		log.Warningf(ctx, "FetchModule: %s@%s zip size %dMi exceeds max %dMi",
+			modulePath, version, zipSize/mib, maxModuleZipSize/mib)
+		return deferFunc, derrors.ModuleTooLarge
+	}
+	return deferFunc, nil
+}
+
+func getZipSize(ctx context.Context, modulePath, resolvedVersion string, prox *proxy.Client) (_ int64, err error) {
+	if modulePath == stdlib.ModulePath {
+		return stdlib.EstimatedZipSize, nil
+	}
+	return prox.ZipSize(ctx, modulePath, resolvedVersion)
+}
+
+// mib is the number of bytes in a mebibyte (Mi).
+const mib = 1024 * 1024
+
+// The largest module zip size we can comfortably process.
+// We probably will OOM if we process a module whose zip is larger.
+var maxModuleZipSize int64 = math.MaxInt64
+
+func init() {
+	v := config.GetEnvInt(context.Background(), "GO_DISCOVERY_MAX_MODULE_ZIP_MI", -1)
+	if v > 0 {
+		maxModuleZipSize = int64(v) * mib
+	}
+}
+
+var zipLoadShedder *loadShedder
+
+func init() {
+	ctx := context.Background()
+	mebis := config.GetEnvInt(ctx, "GO_DISCOVERY_MAX_IN_FLIGHT_ZIP_MI", -1)
+	if mebis > 0 {
+		log.Infof(ctx, "shedding load over %dMi", mebis)
+		zipLoadShedder = &loadShedder{maxSizeInFlight: uint64(mebis) * mib}
+	}
+}
+
+// ZipLoadShedStats returns a snapshot of the current LoadShedStats for zip files.
+func ZipLoadShedStats() LoadShedStats {
+	if zipLoadShedder != nil {
+		return zipLoadShedder.stats()
+	}
+	return LoadShedStats{}
+}
diff --git a/internal/fetch/loadshedding.go b/internal/worker/loadshedding.go
similarity index 98%
rename from internal/fetch/loadshedding.go
rename to internal/worker/loadshedding.go
index 6f609c8..8e31519 100644
--- a/internal/fetch/loadshedding.go
+++ b/internal/worker/loadshedding.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package fetch
+package worker
 
 import (
 	"sync"
diff --git a/internal/fetch/loadshedding_test.go b/internal/worker/loadshedding_test.go
similarity index 98%
rename from internal/fetch/loadshedding_test.go
rename to internal/worker/loadshedding_test.go
index 0d84f32..970747c 100644
--- a/internal/fetch/loadshedding_test.go
+++ b/internal/worker/loadshedding_test.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package fetch
+package worker
 
 import (
 	"math"
diff --git a/internal/worker/pages.go b/internal/worker/pages.go
index 708c022..1c56ee7 100644
--- a/internal/worker/pages.go
+++ b/internal/worker/pages.go
@@ -94,7 +94,7 @@
 		StartTime       time.Time
 		Experiments     []*internal.Experiment
 		Excluded        []string
-		LoadShedStats   fetch.LoadShedStats
+		LoadShedStats   LoadShedStats
 		GoMemStats      runtime.MemStats
 		ProcessStats    processMemStats
 		SystemStats     systemMemStats
@@ -110,7 +110,7 @@
 		StartTime:      startTime,
 		Experiments:    experiments,
 		Excluded:       excluded,
-		LoadShedStats:  fetch.ZipLoadShedStats(),
+		LoadShedStats:  ZipLoadShedStats(),
 		GoMemStats:     gms,
 		ProcessStats:   pms,
 		SystemStats:    sms,