internal/proxy: cache info and mod endpoints

Add more caching to the proxy client so we can call Info and Mod
multiple times during a fetch without worrying about wasted RPCs to
the proxy.

This will enable moving the load shedder, which requires its own info
call, out of the fetch logic and into the worker.

For golang/go#48010

Change-Id: I4e875b1fd5b968aae174cfb93f4cf3a9a2b7a577
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/346729
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/internal/proxy/cache.go b/internal/proxy/cache.go
new file mode 100644
index 0000000..97c2f6d
--- /dev/null
+++ b/internal/proxy/cache.go
@@ -0,0 +1,89 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proxy
+
+import (
+	"archive/zip"
+	"sync"
+
+	"golang.org/x/pkgsite/internal"
+)
+
+// cache caches proxy info, mod and zip calls.
+type cache struct {
+	mu sync.Mutex
+
+	infoCache map[internal.Modver]*VersionInfo
+	modCache  map[internal.Modver][]byte
+
+	// One-element zip cache, to avoid a double download.
+	// See TestFetchAndUpdateStateCacheZip in internal/worker/fetch_test.go.
+	zipKey    internal.Modver
+	zipReader *zip.Reader
+}
+
+func (c *cache) getInfo(modulePath, version string) *VersionInfo {
+	if c == nil {
+		return nil
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	return c.infoCache[internal.Modver{Path: modulePath, Version: version}]
+}
+
+func (c *cache) putInfo(modulePath, version string, v *VersionInfo) {
+	if c == nil {
+		return
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.infoCache == nil {
+		c.infoCache = map[internal.Modver]*VersionInfo{}
+	}
+	c.infoCache[internal.Modver{Path: modulePath, Version: version}] = v
+}
+
+func (c *cache) getMod(modulePath, version string) []byte {
+	if c == nil {
+		return nil
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	return c.modCache[internal.Modver{Path: modulePath, Version: version}]
+}
+
+func (c *cache) putMod(modulePath, version string, b []byte) {
+	if c == nil {
+		return
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.modCache == nil {
+		c.modCache = map[internal.Modver][]byte{}
+	}
+	c.modCache[internal.Modver{Path: modulePath, Version: version}] = b
+}
+
+func (c *cache) getZip(modulePath, version string) *zip.Reader {
+	if c == nil {
+		return nil
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.zipKey == (internal.Modver{Path: modulePath, Version: version}) {
+		return c.zipReader
+	}
+	return nil
+}
+
+func (c *cache) putZip(modulePath, version string, r *zip.Reader) {
+	if c == nil {
+		return
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.zipKey = internal.Modver{Path: modulePath, Version: version}
+	c.zipReader = r
+}
diff --git a/internal/proxy/client.go b/internal/proxy/client.go
index 4240b15..17f4f7b 100644
--- a/internal/proxy/client.go
+++ b/internal/proxy/client.go
@@ -38,12 +38,7 @@
 	// Whether fetch should be disabled.
 	disableFetch bool
 
-	// One-element zip cache, to avoid a double download.
-	// See TestFetchAndUpdateStateCacheZip in internal/worker/fetch_test.go.
-	// Not thread-safe; should be used by only a single request goroutine.
-	rememberLastZip                   bool
-	lastZipModulePath, lastZipVersion string
-	lastZipReader                     *zip.Reader
+	cache *cache
 }
 
 // A VersionInfo contains metadata about a given version of a module.
@@ -81,14 +76,10 @@
 	return c.disableFetch
 }
 
-// WithZipCache returns a new client that caches the last zip
-// it downloads (not thread-safely).
-func (c *Client) WithZipCache() *Client {
+// WithCache returns a new client that caches some RPCs.
+func (c *Client) WithCache() *Client {
 	c2 := *c
-	c2.rememberLastZip = true
-	c2.lastZipModulePath = ""
-	c2.lastZipVersion = ""
-	c2.lastZipReader = nil
+	c2.cache = &cache{}
 	return &c2
 }
 
@@ -107,6 +98,10 @@
 		}
 		wrap(&err, "proxy.Client.Info(%q, %q)", modulePath, requestedVersion)
 	}()
+
+	if v := c.cache.getInfo(modulePath, requestedVersion); v != nil {
+		return v, nil
+	}
 	data, err := c.readBody(ctx, modulePath, requestedVersion, "info")
 	if err != nil {
 		return nil, err
@@ -115,13 +110,23 @@
 	if err := json.Unmarshal(data, &v); err != nil {
 		return nil, err
 	}
+	c.cache.putInfo(modulePath, requestedVersion, &v)
 	return &v, nil
 }
 
 // Mod makes a request to $GOPROXY/<module>/@v/<resolvedVersion>.mod and returns the raw data.
 func (c *Client) Mod(ctx context.Context, modulePath, resolvedVersion string) (_ []byte, err error) {
 	defer derrors.WrapStack(&err, "proxy.Client.Mod(%q, %q)", modulePath, resolvedVersion)
-	return c.readBody(ctx, modulePath, resolvedVersion, "mod")
+
+	if b := c.cache.getMod(modulePath, resolvedVersion); b != nil {
+		return b, nil
+	}
+	b, err := c.readBody(ctx, modulePath, resolvedVersion, "mod")
+	if err != nil {
+		return nil, err
+	}
+	c.cache.putMod(modulePath, resolvedVersion, b)
+	return b, nil
 }
 
 // Zip makes a request to $GOPROXY/<modulePath>/@v/<resolvedVersion>.zip and
@@ -132,8 +137,8 @@
 func (c *Client) Zip(ctx context.Context, modulePath, resolvedVersion string) (_ *zip.Reader, err error) {
 	defer derrors.WrapStack(&err, "proxy.Client.Zip(ctx, %q, %q)", modulePath, resolvedVersion)
 
-	if c.lastZipModulePath == modulePath && c.lastZipVersion == resolvedVersion {
-		return c.lastZipReader, nil
+	if r := c.cache.getZip(modulePath, resolvedVersion); r != nil {
+		return r, nil
 	}
 	bodyBytes, err := c.readBody(ctx, modulePath, resolvedVersion, "zip")
 	if err != nil {
@@ -143,11 +148,7 @@
 	if err != nil {
 		return nil, fmt.Errorf("zip.NewReader: %v: %w", err, derrors.BadModule)
 	}
-	if c.rememberLastZip {
-		c.lastZipModulePath = modulePath
-		c.lastZipVersion = resolvedVersion
-		c.lastZipReader = zipReader
-	}
+	c.cache.putZip(modulePath, resolvedVersion, zipReader)
 	return zipReader, nil
 }
 
diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go
index 7d84c3e..902f93a 100644
--- a/internal/proxy/client_test.go
+++ b/internal/proxy/client_test.go
@@ -354,3 +354,29 @@
 		}
 	}
 }
+
+func TestCache(t *testing.T) {
+	ctx := context.Background()
+	c1, teardownProxy := proxytest.SetupTestClient(t, []*proxytest.Module{testModule})
+
+	c := c1.WithCache()
+	got, err := c.Info(ctx, sample.ModulePath, sample.VersionString)
+	if err != nil {
+		t.Fatal(err)
+	}
+	_ = got
+	teardownProxy()
+	// Need server to satisfy different request.
+	_, err = c.Info(ctx, sample.ModulePath, "v4.5.6")
+	if err == nil {
+		t.Fatal("got nil, want error")
+	}
+	// Don't need server for cached request.
+	got2, err := c.Info(ctx, sample.ModulePath, sample.VersionString)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !cmp.Equal(got, got2) {
+		t.Errorf("got %+v first, then %+v", got, got2)
+	}
+}
diff --git a/internal/worker/fetch_test.go b/internal/worker/fetch_test.go
index de6f51e..19a6077 100644
--- a/internal/worker/fetch_test.go
+++ b/internal/worker/fetch_test.go
@@ -299,10 +299,10 @@
 	}
 
 	sourceClient := source.NewClient(sourceTimeout)
-	f := &Fetcher{proxyClient.WithZipCache(), sourceClient, testDB, nil}
 	for _, test := range testCases {
 		t.Run(strings.ReplaceAll(test.pkg+"@"+test.version, "/", " "), func(t *testing.T) {
 			defer postgres.ResetTestDB(testDB, t)
+			f := &Fetcher{proxyClient.WithCache(), sourceClient, testDB, nil}
 			if _, _, err := f.FetchAndUpdateState(ctx, test.modulePath, test.version, testAppVersion); err != nil {
 				t.Fatalf("FetchAndUpdateState(%q, %q, %v, %v, %v): %v", test.modulePath, test.version, proxyClient, sourceClient, testDB, err)
 			}
@@ -406,7 +406,7 @@
 
 	// With the cache, we download it only once.
 	postgres.ResetTestDB(testDB, t) // to avoid finding has_go_mod in the DB
-	f.ProxyClient = proxyClient.WithZipCache()
+	f.ProxyClient = proxyClient.WithCache()
 	if _, _, err := f.FetchAndUpdateState(ctx, "m.com", "v1.0.0", testAppVersion); err != nil {
 		t.Fatal(err)
 	}
diff --git a/internal/worker/server.go b/internal/worker/server.go
index 64a3c4f..03d7a8c 100644
--- a/internal/worker/server.go
+++ b/internal/worker/server.go
@@ -299,7 +299,7 @@
 	}
 
 	f := &Fetcher{
-		ProxyClient:  s.proxyClient.WithZipCache(),
+		ProxyClient:  s.proxyClient.WithCache(),
 		SourceClient: s.sourceClient,
 		DB:           s.db,
 		Cache:        s.cache,