internal/datasource: factor out bypassing the license check

For golang/go#47780

Change-Id: Ibd516c2dd69a0d1b483d1f3ecc26dde1507ad54e
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/344678
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jamal Carvalho <jamal@golang.org>
diff --git a/internal/datasource/datasource.go b/internal/datasource/datasource.go
index e4ce203..f2149be 100644
--- a/internal/datasource/datasource.go
+++ b/internal/datasource/datasource.go
@@ -24,9 +24,24 @@
 // dataSource implements the internal.DataSource interface, by trying a list of
 // fetch.ModuleGetters to fetch modules and caching the results.
 type dataSource struct {
-	getters      []fetch.ModuleGetter
-	sourceClient *source.Client
-	cache        *lru.Cache
+	getters            []fetch.ModuleGetter
+	sourceClient       *source.Client
+	bypassLicenseCheck bool
+	cache              *lru.Cache
+}
+
+func newDataSource(getters []fetch.ModuleGetter, sc *source.Client, bypassLicenseCheck bool) *dataSource {
+	cache, err := lru.New(maxCachedModules)
+	if err != nil {
+		// Can only happen if size is bad.
+		panic(err)
+	}
+	return &dataSource{
+		getters:            getters,
+		sourceClient:       sc,
+		bypassLicenseCheck: bypassLicenseCheck,
+		cache:              cache,
+	}
 }
 
 // cacheEntry holds a fetched module or an error, if the fetch failed.
@@ -37,19 +52,6 @@
 
 const maxCachedModules = 100
 
-func newDataSource(getters []fetch.ModuleGetter, sc *source.Client) *dataSource {
-	cache, err := lru.New(maxCachedModules)
-	if err != nil {
-		// Can only happen if size is bad.
-		panic(err)
-	}
-	return &dataSource{
-		getters:      getters,
-		sourceClient: sc,
-		cache:        cache,
-	}
-}
-
 // cacheGet returns information from the cache if it is present, and (nil, nil) otherwise.
 func (ds *dataSource) cacheGet(path, version string) (*internal.Module, error) {
 	// Look for an exact match first, then use LocalVersion, as for a
@@ -80,7 +82,16 @@
 		fr := fetch.FetchModule(ctx, modulePath, version, g, ds.sourceClient)
 		defer fr.Defer()
 		if fr.Error == nil {
-			return fr.Module, nil
+			m := fr.Module
+			if ds.bypassLicenseCheck {
+				m.IsRedistributable = true
+				for _, unit := range m.Units {
+					unit.IsRedistributable = true
+				}
+			} else {
+				m.RemoveNonRedistributableData()
+			}
+			return m, nil
 		}
 		if !errors.Is(fr.Error, derrors.NotFound) {
 			return nil, fr.Error
diff --git a/internal/datasource/datasource_test.go b/internal/datasource/datasource_test.go
index 34213be..8207fe8 100644
--- a/internal/datasource/datasource_test.go
+++ b/internal/datasource/datasource_test.go
@@ -13,7 +13,7 @@
 )
 
 func TestCache(t *testing.T) {
-	ds := newDataSource(nil, nil)
+	ds := newDataSource(nil, nil, false)
 	m1 := &internal.Module{}
 	ds.cachePut("m1", fetch.LocalVersion, m1, nil)
 	ds.cachePut("m2", "v1.0.0", nil, derrors.NotFound)
diff --git a/internal/datasource/local.go b/internal/datasource/local.go
index b6fe31f..d0257c4 100644
--- a/internal/datasource/local.go
+++ b/internal/datasource/local.go
@@ -30,7 +30,7 @@
 func NewLocal(getters []fetch.ModuleGetter, sc *source.Client) *LocalDataSource {
 	return &LocalDataSource{
 		sourceClient: sc,
-		ds:           newDataSource(getters, sc),
+		ds:           newDataSource(getters, sc, true),
 	}
 }
 
@@ -49,19 +49,7 @@
 // fetch fetches a module using the configured ModuleGetters.
 // It tries each getter in turn until it finds one that has the module.
 func (ds *LocalDataSource) fetch(ctx context.Context, modulePath, version string) (_ *internal.Module, err error) {
-	m, err := ds.ds.fetch(ctx, modulePath, version)
-	if err != nil {
-		return nil, err
-	}
-	adjust(m)
-	return m, nil
-}
-
-func adjust(m *internal.Module) {
-	m.IsRedistributable = true
-	for _, unit := range m.Units {
-		unit.IsRedistributable = true
-	}
+	return ds.ds.fetch(ctx, modulePath, version)
 }
 
 // NewGOPATHModuleGetter returns a module getter that uses the GOPATH
diff --git a/internal/datasource/proxy.go b/internal/datasource/proxy.go
index 86fc33f..758c7cb 100644
--- a/internal/datasource/proxy.go
+++ b/internal/datasource/proxy.go
@@ -26,19 +26,18 @@
 
 // New returns a new direct proxy datasource.
 func NewProxy(proxyClient *proxy.Client) *ProxyDataSource {
-	return newProxyDataSource(proxyClient, source.NewClient(1*time.Minute))
+	return newProxyDataSource(proxyClient, source.NewClient(1*time.Minute), false)
 }
 
-func NewForTesting(proxyClient *proxy.Client) *ProxyDataSource {
-	return newProxyDataSource(proxyClient, source.NewClientForTesting())
+func NewForTesting(proxyClient *proxy.Client, bypassLicenseCheck bool) *ProxyDataSource {
+	return newProxyDataSource(proxyClient, source.NewClientForTesting(), bypassLicenseCheck)
 }
 
-func newProxyDataSource(proxyClient *proxy.Client, sourceClient *source.Client) *ProxyDataSource {
-	ds := newDataSource([]fetch.ModuleGetter{fetch.NewProxyModuleGetter(proxyClient)}, sourceClient)
+func newProxyDataSource(proxyClient *proxy.Client, sourceClient *source.Client, bypassLicenseCheck bool) *ProxyDataSource {
+	ds := newDataSource([]fetch.ModuleGetter{fetch.NewProxyModuleGetter(proxyClient)}, sourceClient, bypassLicenseCheck)
 	return &ProxyDataSource{
-		ds:                 ds,
-		proxyClient:        proxyClient,
-		bypassLicenseCheck: false,
+		ds:          ds,
+		proxyClient: proxyClient,
 	}
 }
 
@@ -46,9 +45,7 @@
 // license checks. That means all data will be returned for non-redistributable
 // modules, packages and directories.
 func NewBypassingLicenseCheck(c *proxy.Client) *ProxyDataSource {
-	ds := NewProxy(c)
-	ds.bypassLicenseCheck = true
-	return ds
+	return newProxyDataSource(c, source.NewClient(1*time.Minute), true)
 }
 
 // ProxyDataSource implements the frontend.DataSource interface, by querying a
@@ -56,9 +53,8 @@
 type ProxyDataSource struct {
 	proxyClient *proxy.Client
 
-	mu                 sync.Mutex
-	ds                 *dataSource
-	bypassLicenseCheck bool
+	mu sync.Mutex
+	ds *dataSource
 }
 
 // getModule retrieves a version from the cache, or failing that queries and
@@ -76,15 +72,6 @@
 
 	m, err := ds.ds.fetch(ctx, modulePath, version)
 	if m != nil {
-		if ds.bypassLicenseCheck {
-			m.IsRedistributable = true
-			for _, pkg := range m.Packages() {
-				pkg.IsRedistributable = true
-			}
-		} else {
-			m.RemoveNonRedistributableData()
-		}
-		//
 		// Use the go.mod file at the raw latest version to fill in deprecation
 		// and retraction information.
 		lmv, err2 := fetch.LatestModuleVersions(ctx, modulePath, ds.proxyClient, nil)
diff --git a/internal/datasource/proxy_test.go b/internal/datasource/proxy_test.go
index c16e3d5..fcf7a6d 100644
--- a/internal/datasource/proxy_test.go
+++ b/internal/datasource/proxy_test.go
@@ -34,11 +34,11 @@
 	os.Exit(m.Run())
 }
 
-func setup(t *testing.T) (context.Context, *ProxyDataSource, func()) {
+func setup(t *testing.T, bypassLicenseCheck bool) (context.Context, *ProxyDataSource, func()) {
 	t.Helper()
 	client, teardownProxy := proxytest.SetupTestClient(t, testModules)
 	ctx, cancel := context.WithTimeout(context.Background(), 40*time.Second)
-	return ctx, NewForTesting(client), func() {
+	return ctx, NewForTesting(client, bypassLicenseCheck), func() {
 		teardownProxy()
 		cancel()
 	}
@@ -73,7 +73,7 @@
 )
 
 func TestGetModuleInfo(t *testing.T) {
-	ctx, ds, teardown := setup(t)
+	ctx, ds, teardown := setup(t, false)
 	defer teardown()
 
 	modinfo := func(m, v string) *internal.ModuleInfo {
@@ -125,7 +125,7 @@
 }
 
 func TestProxyGetUnitMeta(t *testing.T) {
-	ctx, ds, teardown := setup(t)
+	ctx, ds, teardown := setup(t, false)
 	defer teardown()
 
 	for _, test := range []struct {
@@ -205,9 +205,8 @@
 	for _, bypass := range []bool{false, true} {
 		t.Run(fmt.Sprintf("bypass=%t", bypass), func(t *testing.T) {
 			// re-create the data source to get around caching
-			ctx, ds, teardown := setup(t)
+			ctx, ds, teardown := setup(t, bypass)
 			defer teardown()
-			ds.bypassLicenseCheck = bypass
 			for _, test := range []struct {
 				path      string
 				wantEmpty bool
@@ -275,7 +274,7 @@
 	defer teardownProxy()
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
-	ds := NewForTesting(client)
+	ds := NewForTesting(client, false)
 
 	for _, test := range []struct {
 		fullPath        string