internal/datasource: factor out bypassing the license check
For golang/go#47780
Change-Id: Ibd516c2dd69a0d1b483d1f3ecc26dde1507ad54e
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/344678
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jamal Carvalho <jamal@golang.org>
diff --git a/internal/datasource/datasource.go b/internal/datasource/datasource.go
index e4ce203..f2149be 100644
--- a/internal/datasource/datasource.go
+++ b/internal/datasource/datasource.go
@@ -24,9 +24,24 @@
// dataSource implements the internal.DataSource interface, by trying a list of
// fetch.ModuleGetters to fetch modules and caching the results.
type dataSource struct {
- getters []fetch.ModuleGetter
- sourceClient *source.Client
- cache *lru.Cache
+ getters []fetch.ModuleGetter
+ sourceClient *source.Client
+ bypassLicenseCheck bool
+ cache *lru.Cache
+}
+
+func newDataSource(getters []fetch.ModuleGetter, sc *source.Client, bypassLicenseCheck bool) *dataSource {
+ cache, err := lru.New(maxCachedModules)
+ if err != nil {
+ // Can only happen if size is bad.
+ panic(err)
+ }
+ return &dataSource{
+ getters: getters,
+ sourceClient: sc,
+ bypassLicenseCheck: bypassLicenseCheck,
+ cache: cache,
+ }
}
// cacheEntry holds a fetched module or an error, if the fetch failed.
@@ -37,19 +52,6 @@
const maxCachedModules = 100
-func newDataSource(getters []fetch.ModuleGetter, sc *source.Client) *dataSource {
- cache, err := lru.New(maxCachedModules)
- if err != nil {
- // Can only happen if size is bad.
- panic(err)
- }
- return &dataSource{
- getters: getters,
- sourceClient: sc,
- cache: cache,
- }
-}
-
// cacheGet returns information from the cache if it is present, and (nil, nil) otherwise.
func (ds *dataSource) cacheGet(path, version string) (*internal.Module, error) {
// Look for an exact match first, then use LocalVersion, as for a
@@ -80,7 +82,16 @@
fr := fetch.FetchModule(ctx, modulePath, version, g, ds.sourceClient)
defer fr.Defer()
if fr.Error == nil {
- return fr.Module, nil
+ m := fr.Module
+ if ds.bypassLicenseCheck {
+ m.IsRedistributable = true
+ for _, unit := range m.Units {
+ unit.IsRedistributable = true
+ }
+ } else {
+ m.RemoveNonRedistributableData()
+ }
+ return m, nil
}
if !errors.Is(fr.Error, derrors.NotFound) {
return nil, fr.Error
diff --git a/internal/datasource/datasource_test.go b/internal/datasource/datasource_test.go
index 34213be..8207fe8 100644
--- a/internal/datasource/datasource_test.go
+++ b/internal/datasource/datasource_test.go
@@ -13,7 +13,7 @@
)
func TestCache(t *testing.T) {
- ds := newDataSource(nil, nil)
+ ds := newDataSource(nil, nil, false)
m1 := &internal.Module{}
ds.cachePut("m1", fetch.LocalVersion, m1, nil)
ds.cachePut("m2", "v1.0.0", nil, derrors.NotFound)
diff --git a/internal/datasource/local.go b/internal/datasource/local.go
index b6fe31f..d0257c4 100644
--- a/internal/datasource/local.go
+++ b/internal/datasource/local.go
@@ -30,7 +30,7 @@
func NewLocal(getters []fetch.ModuleGetter, sc *source.Client) *LocalDataSource {
return &LocalDataSource{
sourceClient: sc,
- ds: newDataSource(getters, sc),
+ ds: newDataSource(getters, sc, true),
}
}
@@ -49,19 +49,7 @@
// fetch fetches a module using the configured ModuleGetters.
// It tries each getter in turn until it finds one that has the module.
func (ds *LocalDataSource) fetch(ctx context.Context, modulePath, version string) (_ *internal.Module, err error) {
- m, err := ds.ds.fetch(ctx, modulePath, version)
- if err != nil {
- return nil, err
- }
- adjust(m)
- return m, nil
-}
-
-func adjust(m *internal.Module) {
- m.IsRedistributable = true
- for _, unit := range m.Units {
- unit.IsRedistributable = true
- }
+ return ds.ds.fetch(ctx, modulePath, version)
}
// NewGOPATHModuleGetter returns a module getter that uses the GOPATH
diff --git a/internal/datasource/proxy.go b/internal/datasource/proxy.go
index 86fc33f..758c7cb 100644
--- a/internal/datasource/proxy.go
+++ b/internal/datasource/proxy.go
@@ -26,19 +26,18 @@
// New returns a new direct proxy datasource.
func NewProxy(proxyClient *proxy.Client) *ProxyDataSource {
- return newProxyDataSource(proxyClient, source.NewClient(1*time.Minute))
+ return newProxyDataSource(proxyClient, source.NewClient(1*time.Minute), false)
}
-func NewForTesting(proxyClient *proxy.Client) *ProxyDataSource {
- return newProxyDataSource(proxyClient, source.NewClientForTesting())
+func NewForTesting(proxyClient *proxy.Client, bypassLicenseCheck bool) *ProxyDataSource {
+ return newProxyDataSource(proxyClient, source.NewClientForTesting(), bypassLicenseCheck)
}
-func newProxyDataSource(proxyClient *proxy.Client, sourceClient *source.Client) *ProxyDataSource {
- ds := newDataSource([]fetch.ModuleGetter{fetch.NewProxyModuleGetter(proxyClient)}, sourceClient)
+func newProxyDataSource(proxyClient *proxy.Client, sourceClient *source.Client, bypassLicenseCheck bool) *ProxyDataSource {
+ ds := newDataSource([]fetch.ModuleGetter{fetch.NewProxyModuleGetter(proxyClient)}, sourceClient, bypassLicenseCheck)
return &ProxyDataSource{
- ds: ds,
- proxyClient: proxyClient,
- bypassLicenseCheck: false,
+ ds: ds,
+ proxyClient: proxyClient,
}
}
@@ -46,9 +45,7 @@
// license checks. That means all data will be returned for non-redistributable
// modules, packages and directories.
func NewBypassingLicenseCheck(c *proxy.Client) *ProxyDataSource {
- ds := NewProxy(c)
- ds.bypassLicenseCheck = true
- return ds
+ return newProxyDataSource(c, source.NewClient(1*time.Minute), true)
}
// ProxyDataSource implements the frontend.DataSource interface, by querying a
@@ -56,9 +53,8 @@
type ProxyDataSource struct {
proxyClient *proxy.Client
- mu sync.Mutex
- ds *dataSource
- bypassLicenseCheck bool
+ mu sync.Mutex
+ ds *dataSource
}
// getModule retrieves a version from the cache, or failing that queries and
@@ -76,15 +72,6 @@
m, err := ds.ds.fetch(ctx, modulePath, version)
if m != nil {
- if ds.bypassLicenseCheck {
- m.IsRedistributable = true
- for _, pkg := range m.Packages() {
- pkg.IsRedistributable = true
- }
- } else {
- m.RemoveNonRedistributableData()
- }
- //
// Use the go.mod file at the raw latest version to fill in deprecation
// and retraction information.
lmv, err2 := fetch.LatestModuleVersions(ctx, modulePath, ds.proxyClient, nil)
diff --git a/internal/datasource/proxy_test.go b/internal/datasource/proxy_test.go
index c16e3d5..fcf7a6d 100644
--- a/internal/datasource/proxy_test.go
+++ b/internal/datasource/proxy_test.go
@@ -34,11 +34,11 @@
os.Exit(m.Run())
}
-func setup(t *testing.T) (context.Context, *ProxyDataSource, func()) {
+func setup(t *testing.T, bypassLicenseCheck bool) (context.Context, *ProxyDataSource, func()) {
t.Helper()
client, teardownProxy := proxytest.SetupTestClient(t, testModules)
ctx, cancel := context.WithTimeout(context.Background(), 40*time.Second)
- return ctx, NewForTesting(client), func() {
+ return ctx, NewForTesting(client, bypassLicenseCheck), func() {
teardownProxy()
cancel()
}
@@ -73,7 +73,7 @@
)
func TestGetModuleInfo(t *testing.T) {
- ctx, ds, teardown := setup(t)
+ ctx, ds, teardown := setup(t, false)
defer teardown()
modinfo := func(m, v string) *internal.ModuleInfo {
@@ -125,7 +125,7 @@
}
func TestProxyGetUnitMeta(t *testing.T) {
- ctx, ds, teardown := setup(t)
+ ctx, ds, teardown := setup(t, false)
defer teardown()
for _, test := range []struct {
@@ -205,9 +205,8 @@
for _, bypass := range []bool{false, true} {
t.Run(fmt.Sprintf("bypass=%t", bypass), func(t *testing.T) {
// re-create the data source to get around caching
- ctx, ds, teardown := setup(t)
+ ctx, ds, teardown := setup(t, bypass)
defer teardown()
- ds.bypassLicenseCheck = bypass
for _, test := range []struct {
path string
wantEmpty bool
@@ -275,7 +274,7 @@
defer teardownProxy()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
- ds := NewForTesting(client)
+ ds := NewForTesting(client, false)
for _, test := range []struct {
fullPath string