internal/datasource: factor out getModule

Move the getModule method into the shared implementation.

Move the proxy client into the shared implementation, and set it to
nil for the local data source.

Remove the lock around getModule. As the comment explains, the
resulting race is benign, and we gain the benefit of allowing multiple
goroutines to look up modules concurrently.

Remove the BuildContext arg to getModule. It's not appropriate there,
since getModule's job is to deliver all the information about a module.
We'll re-introduce it (and use it properly) later.

For golang/go#47780

Change-Id: I3e9440c0d6c1b24f7a190a516c9efac1ec0f05bd
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/344949
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jamal Carvalho <jamal@golang.org>
diff --git a/internal/datasource/datasource.go b/internal/datasource/datasource.go
index f2149be..c9a4888 100644
--- a/internal/datasource/datasource.go
+++ b/internal/datasource/datasource.go
@@ -18,6 +18,7 @@
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/fetch"
 	"golang.org/x/pkgsite/internal/log"
+	"golang.org/x/pkgsite/internal/proxy"
 	"golang.org/x/pkgsite/internal/source"
 )
 
@@ -28,9 +29,11 @@
 	sourceClient       *source.Client
 	bypassLicenseCheck bool
 	cache              *lru.Cache
+	prox               *proxy.Client // used for latest-version info only
+
 }
 
-func newDataSource(getters []fetch.ModuleGetter, sc *source.Client, bypassLicenseCheck bool) *dataSource {
+func newDataSource(getters []fetch.ModuleGetter, sc *source.Client, bypassLicenseCheck bool, prox *proxy.Client) *dataSource {
 	cache, err := lru.New(maxCachedModules)
 	if err != nil {
 		// Can only happen if size is bad.
@@ -41,6 +44,7 @@
 		sourceClient:       sc,
 		bypassLicenseCheck: bypassLicenseCheck,
 		cache:              cache,
+		prox:               prox,
 	}
 }
 
@@ -70,6 +74,38 @@
 	ds.cache.Add(internal.Modver{Path: path, Version: version}, cacheEntry{m, err})
 }
 
+// getModule returns the module at the given path and version. A cached result,
+// whether a module or an error, is reused; otherwise the module is fetched.
+func (ds *dataSource) getModule(ctx context.Context, modulePath, version string) (_ *internal.Module, err error) {
+	defer derrors.Wrap(&err, "getModule(%q, %q)", modulePath, version)
+
+	mod, err := ds.cacheGet(modulePath, version)
+	if mod != nil || err != nil { // cache hit: either a module or a remembered error
+		return mod, err
+	}
+
+	// getModule holds no lock across the fetch, so two goroutines may both fetch
+	// the same module. The race is benign: at worst some work is duplicated. If
+	// that turns out to be a problem we could use golang.org/x/sync/singleflight.
+	m, err := ds.fetch(ctx, modulePath, version)
+	if m != nil && ds.prox != nil { // skip latest-version info when no proxy client is configured
+		// Use the go.mod file at the raw latest version to fill in deprecation
+		// and retraction information.
+		lmv, err2 := fetch.LatestModuleVersions(ctx, modulePath, ds.prox, nil)
+		if err2 != nil {
+			err = err2 // the lookup failure overrides whatever error fetch returned
+		} else {
+			lmv.PopulateModuleInfo(&m.ModuleInfo)
+		}
+	}
+
+	// Cache both successes and failures, but not cancellations.
+	if !errors.Is(err, context.Canceled) {
+		ds.cachePut(modulePath, version, m, err)
+	}
+	return m, err // m can be non-nil alongside err when only the latest-version lookup failed
+}
+
 // fetch fetches a module using the configured ModuleGetters.
 // It tries each getter in turn until it finds one that has the module.
 func (ds *dataSource) fetch(ctx context.Context, modulePath, version string) (_ *internal.Module, err error) {
diff --git a/internal/datasource/datasource_test.go b/internal/datasource/datasource_test.go
index 8207fe8..88b8978 100644
--- a/internal/datasource/datasource_test.go
+++ b/internal/datasource/datasource_test.go
@@ -13,7 +13,7 @@
 )
 
 func TestCache(t *testing.T) {
-	ds := newDataSource(nil, nil, false)
+	ds := newDataSource(nil, nil, false, nil)
 	m1 := &internal.Module{}
 	ds.cachePut("m1", fetch.LocalVersion, m1, nil)
 	ds.cachePut("m2", "v1.0.0", nil, derrors.NotFound)
diff --git a/internal/datasource/local.go b/internal/datasource/local.go
index d0257c4..1b0a40e 100644
--- a/internal/datasource/local.go
+++ b/internal/datasource/local.go
@@ -30,28 +30,10 @@
 func NewLocal(getters []fetch.ModuleGetter, sc *source.Client) *LocalDataSource {
 	return &LocalDataSource{
 		sourceClient: sc,
-		ds:           newDataSource(getters, sc, true),
+		ds:           newDataSource(getters, sc, true, nil),
 	}
 }
 
-// getModule gets the module at the given path and version. It first checks the
-// cache, and if it isn't there it then tries to fetch it.
-func (ds *LocalDataSource) getModule(ctx context.Context, path, version string) (*internal.Module, error) {
-	m, err := ds.ds.cacheGet(path, version)
-	if m != nil || err != nil {
-		return m, err
-	}
-	m, err = ds.fetch(ctx, path, version)
-	ds.ds.cachePut(path, version, m, err)
-	return m, err
-}
-
-// fetch fetches a module using the configured ModuleGetters.
-// It tries each getter in turn until it finds one that has the module.
-func (ds *LocalDataSource) fetch(ctx context.Context, modulePath, version string) (_ *internal.Module, err error) {
-	return ds.ds.fetch(ctx, modulePath, version)
-}
-
 // NewGOPATHModuleGetter returns a module getter that uses the GOPATH
 // environment variable to find the module with the given import path.
 func NewGOPATHModuleGetter(importPath string) (_ fetch.ModuleGetter, err error) {
@@ -84,7 +66,7 @@
 func (ds *LocalDataSource) GetUnit(ctx context.Context, pathInfo *internal.UnitMeta, fields internal.FieldSet, bc internal.BuildContext) (_ *internal.Unit, err error) {
 	defer derrors.Wrap(&err, "GetUnit(%q, %q)", pathInfo.Path, pathInfo.ModulePath)
 
-	module, err := ds.getModule(ctx, pathInfo.ModulePath, pathInfo.Version)
+	module, err := ds.ds.getModule(ctx, pathInfo.ModulePath, pathInfo.Version)
 	if err != nil {
 		return nil, err
 	}
@@ -126,11 +108,11 @@
 	defer derrors.Wrap(&err, "findModule(%q, %q, %q)", pkgPath, modulePath, version)
 
 	if modulePath != internal.UnknownModulePath {
-		return ds.getModule(ctx, modulePath, version)
+		return ds.ds.getModule(ctx, modulePath, version)
 	}
 	pkgPath = strings.TrimLeft(pkgPath, "/")
 	for _, modulePath := range internal.CandidateModulePaths(pkgPath) {
-		m, err := ds.getModule(ctx, modulePath, version)
+		m, err := ds.ds.getModule(ctx, modulePath, version)
 		if err == nil {
 			return m, nil
 		}
diff --git a/internal/datasource/proxy.go b/internal/datasource/proxy.go
index 758c7cb..5e34ccf 100644
--- a/internal/datasource/proxy.go
+++ b/internal/datasource/proxy.go
@@ -10,7 +10,6 @@
 	"fmt"
 	"strconv"
 	"strings"
-	"sync"
 	"time"
 
 	"golang.org/x/mod/semver"
@@ -34,10 +33,9 @@
 }
 
 func newProxyDataSource(proxyClient *proxy.Client, sourceClient *source.Client, bypassLicenseCheck bool) *ProxyDataSource {
-	ds := newDataSource([]fetch.ModuleGetter{fetch.NewProxyModuleGetter(proxyClient)}, sourceClient, bypassLicenseCheck)
+	ds := newDataSource([]fetch.ModuleGetter{fetch.NewProxyModuleGetter(proxyClient)}, sourceClient, bypassLicenseCheck, proxyClient)
 	return &ProxyDataSource{
-		ds:          ds,
-		proxyClient: proxyClient,
+		ds: ds,
 	}
 }
 
@@ -51,47 +49,9 @@
 // ProxyDataSource implements the frontend.DataSource interface, by querying a
 // module proxy directly and caching the results in memory.
 type ProxyDataSource struct {
-	proxyClient *proxy.Client
-
-	mu sync.Mutex
 	ds *dataSource
 }
 
-// getModule retrieves a version from the cache, or failing that queries and
-// processes the version from the proxy.
-func (ds *ProxyDataSource) getModule(ctx context.Context, modulePath, version string, _ internal.BuildContext) (_ *internal.Module, err error) {
-	defer derrors.Wrap(&err, "getModule(%q, %q)", modulePath, version)
-
-	ds.mu.Lock()
-	defer ds.mu.Unlock()
-
-	mod, err := ds.ds.cacheGet(modulePath, version)
-	if mod != nil || err != nil {
-		return mod, err
-	}
-
-	m, err := ds.ds.fetch(ctx, modulePath, version)
-	if m != nil {
-		// Use the go.mod file at the raw latest version to fill in deprecation
-		// and retraction information.
-		lmv, err2 := fetch.LatestModuleVersions(ctx, modulePath, ds.proxyClient, nil)
-		if err2 != nil {
-			err = err2
-		} else {
-			lmv.PopulateModuleInfo(&m.ModuleInfo)
-		}
-	}
-
-	if err != nil {
-		if !errors.Is(ctx.Err(), context.Canceled) {
-			ds.ds.cachePut(modulePath, version, m, err)
-		}
-		return nil, err
-	}
-	ds.ds.cachePut(modulePath, version, m, err)
-	return m, nil
-}
-
 // findModule finds the longest module path containing the given package path,
 // using the given finder func and iteratively testing parent directories of
 // the import path. It performs no testing as to whether the specified module
@@ -100,7 +60,7 @@
 	defer derrors.Wrap(&err, "findModule(%q, ...)", pkgPath)
 	pkgPath = strings.TrimLeft(pkgPath, "/")
 	for _, modulePath := range internal.CandidateModulePaths(pkgPath) {
-		info, err := ds.proxyClient.Info(ctx, modulePath, version)
+		info, err := ds.ds.prox.Info(ctx, modulePath, version)
 		if errors.Is(err, derrors.NotFound) {
 			continue
 		}
@@ -113,9 +73,9 @@
 }
 
 // getUnit returns information about a unit.
-func (ds *ProxyDataSource) getUnit(ctx context.Context, fullPath, modulePath, version string, bc internal.BuildContext) (_ *internal.Unit, err error) {
+func (ds *ProxyDataSource) getUnit(ctx context.Context, fullPath, modulePath, version string, _ internal.BuildContext) (_ *internal.Unit, err error) {
 	var m *internal.Module
-	m, err = ds.getModule(ctx, modulePath, version, bc)
+	m, err = ds.ds.getModule(ctx, modulePath, version)
 	if err != nil {
 		return nil, err
 	}
@@ -156,7 +116,7 @@
 func (ds *ProxyDataSource) getLatestMajorVersion(ctx context.Context, fullPath, modulePath string) (_ string, _ string, err error) {
 	// We are checking if the full path is valid so that we can forward the error if not.
 	seriesPath := internal.SeriesPathForModule(modulePath)
-	info, err := ds.proxyClient.Info(ctx, seriesPath, version.Latest)
+	info, err := ds.ds.prox.Info(ctx, seriesPath, version.Latest)
 	if err != nil {
 		return "", "", err
 	}
@@ -180,7 +140,7 @@
 	for v := startVersion; ; v++ {
 		query := fmt.Sprintf("%s/v%d", seriesPath, v)
 
-		_, err := ds.proxyClient.Info(ctx, query, version.Latest)
+		_, err := ds.ds.prox.Info(ctx, query, version.Latest)
 		if errors.Is(err, derrors.NotFound) {
 			if v == 2 {
 				return modulePath, fullPath, nil
diff --git a/internal/datasource/proxy_details.go b/internal/datasource/proxy_details.go
index 1bf07db..d528562 100644
--- a/internal/datasource/proxy_details.go
+++ b/internal/datasource/proxy_details.go
@@ -22,7 +22,7 @@
 // version specified by modulePath and version.
 func (ds *ProxyDataSource) GetModuleInfo(ctx context.Context, modulePath, version string) (_ *internal.ModuleInfo, err error) {
 	defer derrors.Wrap(&err, "GetModuleInfo(%q, %q)", modulePath, version)
-	m, err := ds.getModule(ctx, modulePath, version, internal.BuildContext{})
+	m, err := ds.ds.getModule(ctx, modulePath, version)
 	if err != nil {
 		return nil, err
 	}
@@ -41,7 +41,7 @@
 		}
 		inVersion = info.Version
 	}
-	m, err := ds.getModule(ctx, inModulePath, inVersion, internal.BuildContext{})
+	m, err := ds.ds.getModule(ctx, inModulePath, inVersion)
 	if err != nil {
 		return nil, err
 	}