internal/proxy: support Disable-Module-Fetch header in fetch

Add Client.GetInfoNoFetch, which sets the Disable-Module-Fetch
header to avoid fetching a module from the proxy if it isn't in
the proxy's cache. This reduces load on the proxy.

We return the new error code NotFetched instead of NotFound in this
case, so we can distinguish modules where this happened in the DB.

Change-Id: I2134d09e09b115e2ed59ba1a479ef20c1ebe4a7e
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/277812
Trust: Jonathan Amsterdam <jba@google.com>
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/derrors/derrors.go b/internal/derrors/derrors.go
index de0612c..4983381 100644
--- a/internal/derrors/derrors.go
+++ b/internal/derrors/derrors.go
@@ -21,6 +21,12 @@
 
 	// NotFound indicates that a requested entity was not found (HTTP 404).
 	NotFound = errors.New("not found")
+
+	// NotFetched means that the proxy returned "not found" for a request
+	// that had the Disable-Module-Fetch header set. We don't know if the
+	// module really doesn't exist, or the proxy just didn't fetch it.
+	NotFetched = errors.New("not fetched by proxy")
+
 	// InvalidArgument indicates that the input into the request is invalid in
 	// some way (HTTP 400).
 	InvalidArgument = errors.New("invalid argument")
@@ -107,6 +113,7 @@
 	// Since the following aren't HTTP statuses, pick unused codes.
 	{HasIncompletePackages, 290},
 	{DBModuleInsertInvalid, 480},
+	{NotFetched, 481},
 	{BadModule, 490},
 	{AlternativeModule, 491},
 	{ModuleTooLarge, 492},
diff --git a/internal/proxy/client.go b/internal/proxy/client.go
index a188161..4c2e7a9 100644
--- a/internal/proxy/client.go
+++ b/internal/proxy/client.go
@@ -42,6 +42,10 @@
 	Time    time.Time
 }
 
+// disableFetchHeader is the request header that, when set to "true",
+// prevents the proxy from fetching uncached modules.
+const disableFetchHeader = "Disable-Module-Fetch"
+
 // New constructs a *Client using the provided url, which is expected to
 // be an absolute URI that can be directly passed to http.Get.
 func New(u string) (_ *Client, err error) {
@@ -55,8 +59,19 @@
 // GetInfo makes a request to $GOPROXY/<module>/@v/<requestedVersion>.info and
 // transforms that data into a *VersionInfo.
 func (c *Client) GetInfo(ctx context.Context, modulePath, requestedVersion string) (_ *VersionInfo, err error) {
+	return c.getInfo(ctx, modulePath, requestedVersion, false)
+}
+
+// GetInfoNoFetch behaves like GetInfo, except that it sets the
+// Disable-Module-Fetch header so that the proxy does not fetch a module it
+// doesn't already know about.
+func (c *Client) GetInfoNoFetch(ctx context.Context, modulePath, requestedVersion string) (_ *VersionInfo, err error) {
+	return c.getInfo(ctx, modulePath, requestedVersion, true)
+}
+
+func (c *Client) getInfo(ctx context.Context, modulePath, requestedVersion string, disableFetch bool) (_ *VersionInfo, err error) {
 	defer derrors.Wrap(&err, "proxy.Client.GetInfo(%q, %q)", modulePath, requestedVersion)
-	data, err := c.readBody(ctx, modulePath, requestedVersion, "info")
+	data, err := c.readBody(ctx, modulePath, requestedVersion, "info", disableFetch)
 	if err != nil {
 		return nil, err
 	}
@@ -70,7 +85,7 @@
 // GetMod makes a request to $GOPROXY/<module>/@v/<resolvedVersion>.mod and returns the raw data.
 func (c *Client) GetMod(ctx context.Context, modulePath, resolvedVersion string) (_ []byte, err error) {
 	defer derrors.Wrap(&err, "proxy.Client.GetMod(%q, %q)", modulePath, resolvedVersion)
-	return c.readBody(ctx, modulePath, resolvedVersion, "mod")
+	return c.readBody(ctx, modulePath, resolvedVersion, "mod", false)
 }
 
 // GetZip makes a request to $GOPROXY/<modulePath>/@v/<resolvedVersion>.zip and
@@ -81,7 +96,7 @@
 func (c *Client) GetZip(ctx context.Context, modulePath, resolvedVersion string) (_ *zip.Reader, err error) {
 	defer derrors.Wrap(&err, "proxy.Client.GetZip(ctx, %q, %q)", modulePath, resolvedVersion)
 
-	bodyBytes, err := c.readBody(ctx, modulePath, resolvedVersion, "zip")
+	bodyBytes, err := c.readBody(ctx, modulePath, resolvedVersion, "zip", false)
 	if err != nil {
 		return nil, err
 	}
@@ -106,7 +121,7 @@
 		return 0, fmt.Errorf("ctxhttp.Head(ctx, client, %q): %v", url, err)
 	}
 	defer res.Body.Close()
-	if err := responseError(res); err != nil {
+	if err := responseError(res, false); err != nil {
 		return 0, err
 	}
 	if res.ContentLength < 0 {
@@ -140,7 +155,7 @@
 	return fmt.Sprintf("%s/%s/@v/%s.%s", c.url, escapedPath, escapedVersion, suffix), nil
 }
 
-func (c *Client) readBody(ctx context.Context, modulePath, requestedVersion, suffix string) (_ []byte, err error) {
+func (c *Client) readBody(ctx context.Context, modulePath, requestedVersion, suffix string, disableFetch bool) (_ []byte, err error) {
 	defer derrors.Wrap(&err, "Client.readBody(%q, %q, %q)", modulePath, requestedVersion, suffix)
 
 	u, err := c.escapedURL(modulePath, requestedVersion, suffix)
@@ -148,7 +163,7 @@
 		return nil, err
 	}
 	var data []byte
-	err = c.executeRequest(ctx, u, func(body io.Reader) error {
+	err = c.executeRequest(ctx, u, disableFetch, func(body io.Reader) error {
 		var err error
 		data, err = ioutil.ReadAll(body)
 		return err
@@ -175,7 +190,7 @@
 		}
 		return scanner.Err()
 	}
-	if err := c.executeRequest(ctx, u, collect); err != nil {
+	if err := c.executeRequest(ctx, u, false, collect); err != nil {
 		return nil, err
 	}
 	return versions, nil
@@ -183,26 +198,34 @@
 
 // executeRequest executes an HTTP GET request for u, then calls the bodyFunc
 // on the response body, if no error occurred.
-func (c *Client) executeRequest(ctx context.Context, u string, bodyFunc func(body io.Reader) error) (err error) {
+func (c *Client) executeRequest(ctx context.Context, u string, disableFetch bool, bodyFunc func(body io.Reader) error) (err error) {
 	defer func() {
 		if ctx.Err() != nil {
 			err = fmt.Errorf("%v: %w", err, derrors.ProxyTimedOut)
 		}
 		derrors.Wrap(&err, "executeRequest(ctx, %q)", u)
 	}()
-	r, err := ctxhttp.Get(ctx, c.httpClient, u)
+
+	req, err := http.NewRequest("GET", u, nil)
 	if err != nil {
-		return fmt.Errorf("ctxhttp.Get(ctx, client, %q): %v", u, err)
+		return err
+	}
+	if disableFetch {
+		req.Header.Set(disableFetchHeader, "true")
+	}
+	r, err := ctxhttp.Do(ctx, c.httpClient, req)
+	if err != nil {
+		return fmt.Errorf("ctxhttp.Do(ctx, client, %q): %v", u, err)
 	}
 	defer r.Body.Close()
-	if err := responseError(r); err != nil {
+	if err := responseError(r, disableFetch); err != nil {
 		return err
 	}
 	return bodyFunc(r.Body)
 }
 
 // responseError translates the response status code to an appropriate error.
-func responseError(r *http.Response) error {
+func responseError(r *http.Response, fetchDisabled bool) error {
 	switch {
 	case 200 <= r.StatusCode && r.StatusCode < 300:
 		return nil
@@ -213,15 +236,23 @@
 		// If the response body contains "fetch timed out", treat this
 		// as a 504 response so that we retry fetching the module version again
 		// later.
+		//
+		// If the Disable-Module-Fetch header was set, use a different
+		// error code so we can tell the difference.
 		data, err := ioutil.ReadAll(r.Body)
 		if err != nil {
 			return fmt.Errorf("ioutil.readall: %v", err)
 		}
 		d := string(data)
-		if strings.Contains(d, "fetch timed out") {
-			return fmt.Errorf("%q: %w", d, derrors.ProxyTimedOut)
+		switch {
+		case strings.Contains(d, "fetch timed out"):
+			err = derrors.ProxyTimedOut
+		case fetchDisabled:
+			err = derrors.NotFetched
+		default:
+			err = derrors.NotFound
 		}
-		return fmt.Errorf("%q: %w", d, derrors.NotFound)
+		return fmt.Errorf("%q: %w", d, err)
 	default:
 		return fmt.Errorf("unexpected status %d %s", r.StatusCode, r.Status)
 	}
diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go
index 532823e..dcf3e4c 100644
--- a/internal/proxy/client_test.go
+++ b/internal/proxy/client_test.go
@@ -55,6 +55,14 @@
 	},
 }
 
+const uncachedModulePath = "example.com/uncached"
+
+var uncachedModule = &Module{
+	ModulePath: uncachedModulePath,
+	Version:    sample.VersionString,
+	NotCached:  true,
+}
+
 func TestGetLatestInfo(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
@@ -122,7 +130,7 @@
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
 
-	client, teardownProxy := SetupTestClient(t, []*Module{testModule})
+	client, teardownProxy := SetupTestClient(t, []*Module{testModule, uncachedModule})
 	defer teardownProxy()
 
 	info, err := client.GetInfo(ctx, sample.ModulePath, sample.VersionString)
@@ -138,6 +146,17 @@
 	if info.Time != expectedTime {
 		t.Errorf("VersionInfo.Time for GetInfo(ctx, %q, %q) = %v, want %v", sample.ModulePath, sample.VersionString, info.Time, expectedTime)
 	}
+
+	// GetInfoNoFetch returns "NotFetched" error on uncached module.
+	_, err = client.GetInfoNoFetch(ctx, uncachedModulePath, sample.VersionString)
+	if !errors.Is(err, derrors.NotFetched) {
+		t.Fatalf("got %v, want NotFetched", err)
+	}
+	// GetInfoNoFetch succeeds on cached module.
+	_, err = client.GetInfoNoFetch(ctx, sample.ModulePath, sample.VersionString)
+	if err != nil {
+		t.Fatal(err)
+	}
 }
 
 func TestGetInfo_Errors(t *testing.T) {
diff --git a/internal/proxy/server.go b/internal/proxy/server.go
index 51fffa1..2d43f19 100644
--- a/internal/proxy/server.go
+++ b/internal/proxy/server.go
@@ -29,6 +29,7 @@
 	ModulePath string
 	Version    string
 	Files      map[string]string
+	NotCached  bool // if true, behaves like it's uncached
 	zip        []byte
 }
 
@@ -45,9 +46,13 @@
 }
 
 // handleInfo creates an info endpoint for the specified module version.
-func (s *Server) handleInfo(modulePath, resolvedVersion string) {
+func (s *Server) handleInfo(modulePath, resolvedVersion string, uncached bool) {
 	urlPath := fmt.Sprintf("/%s/@v/%s.info", modulePath, resolvedVersion)
 	s.mux.HandleFunc(urlPath, func(w http.ResponseWriter, r *http.Request) {
+		if uncached && r.Header.Get(disableFetchHeader) == "true" {
+			http.Error(w, "not found: temporarily unavailable", http.StatusGone)
+			return
+		}
 		http.ServeContent(w, r, modulePath, time.Now(), defaultInfo(resolvedVersion))
 	})
 }
@@ -122,7 +127,7 @@
 		s.handleLatest(m.ModulePath, fmt.Sprintf("/%s/@v/master.info", m.ModulePath))
 		s.handleLatest(m.ModulePath, fmt.Sprintf("/%s/@v/main.info", m.ModulePath))
 	}
-	s.handleInfo(m.ModulePath, m.Version)
+	s.handleInfo(m.ModulePath, m.Version, m.NotCached)
 	s.handleMod(m)
 	s.handleZip(m)