internal/datasource: factor out fetch

For golang/go#47780

Change-Id: I2133f43b4bd150f23f390807e4def8ee20f26d4d
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/344677
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jamal Carvalho <jamal@golang.org>
diff --git a/internal/datasource/datasource.go b/internal/datasource/datasource.go
index 634ebd0..e4ce203 100644
--- a/internal/datasource/datasource.go
+++ b/internal/datasource/datasource.go
@@ -8,9 +8,16 @@
 package datasource
 
 import (
+	"context"
+	"errors"
+	"fmt"
+	"time"
+
 	lru "github.com/hashicorp/golang-lru"
 	"golang.org/x/pkgsite/internal"
+	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/fetch"
+	"golang.org/x/pkgsite/internal/log"
 	"golang.org/x/pkgsite/internal/source"
 )
 
@@ -60,3 +67,24 @@
 func (ds *dataSource) cachePut(path, version string, m *internal.Module, err error) {
 	ds.cache.Add(internal.Modver{Path: path, Version: version}, cacheEntry{m, err})
 }
+
+// fetch fetches a module using the configured ModuleGetters.
+// It tries each getter in turn until it finds one that has the module.
+func (ds *dataSource) fetch(ctx context.Context, modulePath, version string) (_ *internal.Module, err error) {
+	log.Infof(ctx, "DataSource: fetching %s@%s", modulePath, version)
+	start := time.Now()
+	defer func() {
+		log.Infof(ctx, "DataSource: fetched %s@%s in %s with error %v", modulePath, version, time.Since(start), err)
+	}()
+	for _, g := range ds.getters {
+		fr := fetch.FetchModule(ctx, modulePath, version, g, ds.sourceClient)
+		defer fr.Defer()
+		if fr.Error == nil {
+			return fr.Module, nil
+		}
+		if !errors.Is(fr.Error, derrors.NotFound) {
+			return nil, fr.Error
+		}
+	}
+	return nil, fmt.Errorf("%s@%s: %w", modulePath, version, derrors.NotFound)
+}
diff --git a/internal/datasource/local.go b/internal/datasource/local.go
index 29cfb7a..b6fe31f 100644
--- a/internal/datasource/local.go
+++ b/internal/datasource/local.go
@@ -11,12 +11,10 @@
 	"os"
 	"path/filepath"
 	"strings"
-	"time"
 
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/fetch"
-	"golang.org/x/pkgsite/internal/log"
 	"golang.org/x/pkgsite/internal/source"
 )
 
@@ -51,22 +49,12 @@
 // fetch fetches a module using the configured ModuleGetters.
 // It tries each getter in turn until it finds one that has the module.
 func (ds *LocalDataSource) fetch(ctx context.Context, modulePath, version string) (_ *internal.Module, err error) {
-	log.Infof(ctx, "local DataSource: fetching %s@%s", modulePath, version)
-	start := time.Now()
-	defer func() {
-		log.Infof(ctx, "local DataSource: fetched %s@%s in %s with error %v", modulePath, version, time.Since(start), err)
-	}()
-	for _, g := range ds.ds.getters {
-		fr := fetch.FetchModule(ctx, modulePath, version, g, ds.sourceClient)
-		if fr.Error == nil {
-			adjust(fr.Module)
-			return fr.Module, nil
-		}
-		if !errors.Is(fr.Error, derrors.NotFound) {
-			return nil, fr.Error
-		}
+	m, err := ds.ds.fetch(ctx, modulePath, version)
+	if err != nil {
+		return nil, err
 	}
-	return nil, fmt.Errorf("%s@%s: %w", modulePath, version, derrors.NotFound)
+	adjust(m)
+	return m, nil
 }
 
 func adjust(m *internal.Module) {
diff --git a/internal/datasource/proxy.go b/internal/datasource/proxy.go
index 2398c8d..86fc33f 100644
--- a/internal/datasource/proxy.go
+++ b/internal/datasource/proxy.go
@@ -73,9 +73,8 @@
 	if mod != nil || err != nil {
 		return mod, err
 	}
-	res := fetch.FetchModule(ctx, modulePath, version, ds.ds.getters[0], ds.ds.sourceClient)
-	defer res.Defer()
-	m := res.Module
+
+	m, err := ds.ds.fetch(ctx, modulePath, version)
 	if m != nil {
 		if ds.bypassLicenseCheck {
 			m.IsRedistributable = true
@@ -88,22 +87,21 @@
 		//
 		// Use the go.mod file at the raw latest version to fill in deprecation
 		// and retraction information.
-		lmv, err := fetch.LatestModuleVersions(ctx, modulePath, ds.proxyClient, nil)
-		if err != nil {
-			res.Error = err
+		lmv, err2 := fetch.LatestModuleVersions(ctx, modulePath, ds.proxyClient, nil)
+		if err2 != nil {
+			err = err2
 		} else {
 			lmv.PopulateModuleInfo(&m.ModuleInfo)
 		}
 	}
 
-	if res.Error != nil {
+	if err != nil {
 		if !errors.Is(ctx.Err(), context.Canceled) {
-			ds.ds.cachePut(modulePath, version, m, res.Error)
+			ds.ds.cachePut(modulePath, version, m, err)
 		}
-		return nil, res.Error
+		return nil, err
 	}
 	ds.ds.cachePut(modulePath, version, m, err)
-
 	return m, nil
 }