internal/localdatasource: use getters
Instead of loading a list of modules initially, a local datasource now
takes a list of ModuleGetters, which are called on demand when a
module is requested.
For golang/go#47780
Change-Id: Ica3cc6d47de01ec78c451c0ef54b9ba0e0c5a96e
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/343591
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/cmd/pkgsite/main.go b/cmd/pkgsite/main.go
index c2147c6..ed03191 100644
--- a/cmd/pkgsite/main.go
+++ b/cmd/pkgsite/main.go
@@ -30,10 +30,12 @@
"github.com/google/safehtml/template"
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/dcensus"
+ "golang.org/x/pkgsite/internal/fetch"
"golang.org/x/pkgsite/internal/frontend"
"golang.org/x/pkgsite/internal/localdatasource"
"golang.org/x/pkgsite/internal/log"
"golang.org/x/pkgsite/internal/middleware"
+ "golang.org/x/pkgsite/internal/source"
)
const defaultAddr = "localhost:8080" // default webserver address
@@ -53,7 +55,7 @@
paths = "."
}
- lds := localdatasource.New()
+ lds := localdatasource.New(source.NewClient(time.Second))
dsg := func(context.Context) internal.DataSource { return lds }
server, err := frontend.NewServer(frontend.ServerConfig{
DataSourceGetter: dsg,
@@ -78,15 +80,20 @@
paths := strings.Split(pathList, ",")
loaded := len(paths)
for _, path := range paths {
- var err error
+ var (
+ mg fetch.ModuleGetter
+ err error
+ )
if *gopathMode {
- err = ds.LoadFromGOPATH(ctx, path)
+ mg, err = localdatasource.NewGOPATHModuleGetter(path)
} else {
- err = ds.Load(ctx, path)
+ mg, err = fetch.NewDirectoryModuleGetter("", path)
}
if err != nil {
log.Error(ctx, err)
loaded--
+ } else {
+ ds.AddModuleGetter(mg)
}
}
diff --git a/internal/localdatasource/datasource.go b/internal/localdatasource/datasource.go
index 69a0c6c..0188e34 100644
--- a/internal/localdatasource/datasource.go
+++ b/internal/localdatasource/datasource.go
@@ -9,13 +9,13 @@
import (
"context"
+ "errors"
"fmt"
"os"
"path"
"path/filepath"
"strings"
"sync"
- "time"
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/derrors"
@@ -29,54 +29,78 @@
sourceClient *source.Client
mu sync.Mutex
+ getters []fetch.ModuleGetter
loadedModules map[string]*internal.Module
}
// New creates and returns a new local datasource that bypasses license
// checks by default.
-func New() *DataSource {
+func New(sc *source.Client) *DataSource {
return &DataSource{
- sourceClient: source.NewClient(1 * time.Minute),
+ sourceClient: sc,
loadedModules: make(map[string]*internal.Module),
}
}
-// Load loads a module from the given local path. Loading is required before
-// being able to display the module.
-func (ds *DataSource) Load(ctx context.Context, localPath string) (err error) {
- defer derrors.Wrap(&err, "Load(%q)", localPath)
- return ds.fetch(ctx, "", localPath)
+// AddModuleGetter adds a module getter to the DataSource. To look up a module,
+// the getters are tried in the order they were added until the desired module
+// is found.
+func (ds *DataSource) AddModuleGetter(g fetch.ModuleGetter) {
+ ds.mu.Lock()
+ defer ds.mu.Unlock()
+ ds.getters = append(ds.getters, g)
}
-// LoadFromGOPATH loads a module from GOPATH using the given import path. The full
-// path of the module should be GOPATH/src/importPath. If several GOPATHs exist, the
-// module is loaded from the first one that contains the import path. Loading is required
-// before being able to display the module.
-func (ds *DataSource) LoadFromGOPATH(ctx context.Context, importPath string) (err error) {
- defer derrors.Wrap(&err, "LoadFromGOPATH(%q)", importPath)
-
- path := getFullPath(importPath)
- if path == "" {
- return fmt.Errorf("path %s doesn't exist: %w", importPath, derrors.NotFound)
+// getModule gets the module at the given path and version. It first checks the
+// cache, and if it isn't there it then tries to fetch it.
+func (ds *DataSource) getModule(ctx context.Context, path, version string) (*internal.Module, error) {
+ if m := ds.getFromCache(path, version); m != nil {
+ return m, nil
}
-
- return ds.fetch(ctx, importPath, path)
+ m, err := ds.fetch(ctx, path, version)
+ if err != nil {
+ return nil, err
+ }
+ ds.mu.Lock()
+ defer ds.mu.Unlock()
+ ds.loadedModules[m.ModulePath+"@"+m.Version] = m
+ return m, nil
}
-// fetch fetches a module using FetchLocalModule and adds it to the datasource.
-// If the fetching fails, an error is returned.
-func (ds *DataSource) fetch(ctx context.Context, modulePath, localPath string) error {
- fr := fetch.FetchLocalModule(ctx, modulePath, localPath, ds.sourceClient)
- if fr.Error != nil {
- return fr.Error
+// getFromCache returns a module from the cache if it is present, and nil otherwise.
+func (ds *DataSource) getFromCache(path, version string) *internal.Module {
+ ds.mu.Lock()
+ defer ds.mu.Unlock()
+ // Look for an exact match first.
+ if m := ds.loadedModules[path+"@"+version]; m != nil {
+ return m
}
+ // Look for the module path with LocalVersion, as for a directory-based or GOPATH-mode module.
+ return ds.loadedModules[path+"@"+fetch.LocalVersion]
+}
- fr.Module.IsRedistributable = true
- for _, unit := range fr.Module.Units {
+// fetch fetches a module using the configured ModuleGetters.
+// It tries each getter in turn until it finds one that has the module.
+func (ds *DataSource) fetch(ctx context.Context, modulePath, version string) (*internal.Module, error) {
+ for _, g := range ds.getters {
+ fr := fetch.FetchModule(ctx, modulePath, version, g, ds.sourceClient)
+ if fr.Error == nil {
+ adjust(fr.Module)
+ return fr.Module, nil
+ }
+ if !errors.Is(fr.Error, derrors.NotFound) {
+ return nil, fr.Error
+ }
+ }
+ return nil, fmt.Errorf("%s@%s: %w", modulePath, version, derrors.NotFound)
+}
+
+func adjust(m *internal.Module) {
+ m.IsRedistributable = true
+ for _, unit := range m.Units {
unit.IsRedistributable = true
}
-
- for _, unit := range fr.Module.Units {
+ for _, unit := range m.Units {
for _, d := range unit.Documentation {
unit.BuildContexts = append(unit.BuildContexts, internal.BuildContext{
GOOS: d.GOOS,
@@ -84,11 +108,18 @@
})
}
}
+}
- ds.mu.Lock()
- defer ds.mu.Unlock()
- ds.loadedModules[fr.ModulePath] = fr.Module
- return nil
+// NewGOPATHModuleGetter returns a module getter that uses the GOPATH
+// environment variable to find the module with the given import path.
+func NewGOPATHModuleGetter(importPath string) (_ fetch.ModuleGetter, err error) {
+ defer derrors.Wrap(&err, "NewGOPATHModuleGetter(%q)", importPath)
+
+ dir := getFullPath(importPath)
+ if dir == "" {
+ return nil, fmt.Errorf("path %s doesn't exist: %w", importPath, derrors.NotFound)
+ }
+ return fetch.NewDirectoryModuleGetter(importPath, dir)
}
// getFullPath takes an import path, tests it relative to each GOPATH, and returns
@@ -111,51 +142,30 @@
func (ds *DataSource) GetUnit(ctx context.Context, pathInfo *internal.UnitMeta, fields internal.FieldSet, bc internal.BuildContext) (_ *internal.Unit, err error) {
defer derrors.Wrap(&err, "GetUnit(%q, %q)", pathInfo.Path, pathInfo.ModulePath)
- modulepath := pathInfo.ModulePath
- path := pathInfo.Path
-
- ds.mu.Lock()
- defer ds.mu.Unlock()
- if ds.loadedModules[modulepath] == nil {
- return nil, fmt.Errorf("%s not loaded: %w", modulepath, derrors.NotFound)
+ module, err := ds.getModule(ctx, pathInfo.ModulePath, pathInfo.Version)
+ if err != nil {
+ return nil, err
}
-
- module := ds.loadedModules[modulepath]
for _, unit := range module.Units {
- if unit.Path == path {
+ if unit.Path == pathInfo.Path {
return unit, nil
}
}
- return nil, fmt.Errorf("%s not found: %w", path, derrors.NotFound)
+ return nil, fmt.Errorf("import path %s not found in module %s: %w", pathInfo.Path, pathInfo.ModulePath, derrors.NotFound)
}
// GetUnitMeta returns information about a path.
func (ds *DataSource) GetUnitMeta(ctx context.Context, path, requestedModulePath, requestedVersion string) (_ *internal.UnitMeta, err error) {
defer derrors.Wrap(&err, "GetUnitMeta(%q, %q, %q)", path, requestedModulePath, requestedVersion)
- if requestedModulePath == internal.UnknownModulePath {
- requestedModulePath, err = ds.findModule(path)
- if err != nil {
- return nil, err
- }
+ module, err := ds.findModule(ctx, path, requestedModulePath, requestedVersion)
+ if err != nil {
+ return nil, err
}
-
- ds.mu.Lock()
- module := ds.loadedModules[requestedModulePath]
- ds.mu.Unlock()
- if module == nil {
- return nil, fmt.Errorf("%s not loaded: %w", requestedModulePath, derrors.NotFound)
- }
-
um := &internal.UnitMeta{
- Path: path,
- ModuleInfo: internal.ModuleInfo{
- ModulePath: requestedModulePath,
- Version: fetch.LocalVersion,
- CommitTime: fetch.LocalCommitTime,
- IsRedistributable: module.IsRedistributable,
- },
+ Path: path,
+ ModuleInfo: module.ModuleInfo,
}
for _, u := range module.Units {
@@ -168,23 +178,26 @@
return um, nil
}
-// findModule finds the longest module path in loadedModules containing the given
-// package path. It iteratively checks parent directories to find an import path.
-// Returns an error if no module is found.
-func (ds *DataSource) findModule(pkgPath string) (_ string, err error) {
+// findModule finds the module with longest module path containing the given
+// package path. It returns an error if no module is found.
+func (ds *DataSource) findModule(ctx context.Context, pkgPath, modulePath, version string) (_ *internal.Module, err error) {
defer derrors.Wrap(&err, "findModule(%q)", pkgPath)
- pkgPath = strings.TrimLeft(pkgPath, "/")
-
- ds.mu.Lock()
- defer ds.mu.Unlock()
- for modulePath := pkgPath; modulePath != "" && modulePath != "."; modulePath = path.Dir(modulePath) {
- if ds.loadedModules[modulePath] != nil {
- return modulePath, nil
- }
+ if modulePath != internal.UnknownModulePath {
+ return ds.getModule(ctx, modulePath, version)
}
- return "", fmt.Errorf("%s not loaded: %w", pkgPath, derrors.NotFound)
+ pkgPath = strings.TrimLeft(pkgPath, "/")
+ for modulePath := pkgPath; modulePath != "" && modulePath != "."; modulePath = path.Dir(modulePath) {
+ m, err := ds.getModule(ctx, modulePath, version)
+ if err == nil {
+ return m, nil
+ }
+ if !errors.Is(err, derrors.NotFound) {
+ return nil, err
+ }
+ }
+ return nil, fmt.Errorf("could not find module for import path %s: %w", pkgPath, derrors.NotFound)
}
// GetLatestInfo is not implemented.
diff --git a/internal/localdatasource/datasource_test.go b/internal/localdatasource/datasource_test.go
index b83bb48..0a96178 100644
--- a/internal/localdatasource/datasource_test.go
+++ b/internal/localdatasource/datasource_test.go
@@ -7,6 +7,7 @@
import (
"context"
"errors"
+ "log"
"os"
"testing"
"time"
@@ -18,23 +19,18 @@
"golang.org/x/pkgsite/internal/fetch"
"golang.org/x/pkgsite/internal/godoc/dochtml"
"golang.org/x/pkgsite/internal/licenses"
+ "golang.org/x/pkgsite/internal/source"
"golang.org/x/pkgsite/internal/stdlib"
"golang.org/x/pkgsite/internal/testing/testhelper"
)
-var (
- ctx context.Context
- cancel func()
- datasource *DataSource
-)
+func TestMain(m *testing.M) {
+ os.Exit(run(m))
+}
-func setup(t *testing.T) (context.Context, func(), *DataSource, error) {
- t.Helper()
+var datasource *DataSource
- // Setup only once.
- if datasource != nil {
- return ctx, cancel, datasource, nil
- }
+func run(m *testing.M) int {
licenses.OmitExceptions = true
modules := []map[string]string{
{
@@ -82,31 +78,28 @@
}
dochtml.LoadTemplates(template.TrustedSourceFromConstant("../../static/doc"))
- datasource = New()
- ctx, cancel = context.WithTimeout(context.Background(), 20*time.Second)
+ datasource = New(source.NewClientForTesting())
for _, module := range modules {
directory, err := testhelper.CreateTestDirectory(module)
if err != nil {
- return ctx, func() { cancel() }, nil, err
+ log.Fatal(err)
}
defer os.RemoveAll(directory)
- err = datasource.Load(ctx, directory)
+ mg, err := fetch.NewDirectoryModuleGetter("", directory)
if err != nil {
- return ctx, func() { cancel() }, nil, err
+ log.Fatal(err)
}
+ datasource.AddModuleGetter(mg)
}
-
- return ctx, func() { cancel() }, datasource, nil
+ return m.Run()
}
func TestGetUnitMeta(t *testing.T) {
- ctx, cancel, ds, err := setup(t)
- if err != nil {
- t.Fatalf("setup failed: %s", err.Error())
- }
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
+ sourceInfo := source.NewGitHubInfo("https://github.com/my/module", "", "v0.0.0")
for _, test := range []struct {
path, modulePath string
want *internal.UnitMeta
@@ -122,6 +115,8 @@
Version: fetch.LocalVersion,
CommitTime: fetch.LocalCommitTime,
IsRedistributable: true,
+ HasGoMod: true,
+ SourceInfo: sourceInfo,
},
IsRedistributable: true,
},
@@ -137,6 +132,8 @@
Version: fetch.LocalVersion,
CommitTime: fetch.LocalCommitTime,
IsRedistributable: true,
+ HasGoMod: true,
+ SourceInfo: sourceInfo,
},
IsRedistributable: true,
},
@@ -152,6 +149,8 @@
IsRedistributable: true,
Version: fetch.LocalVersion,
CommitTime: fetch.LocalCommitTime,
+ HasGoMod: true,
+ SourceInfo: sourceInfo,
},
IsRedistributable: true,
},
@@ -168,6 +167,8 @@
Version: fetch.LocalVersion,
CommitTime: fetch.LocalCommitTime,
IsRedistributable: true,
+ HasGoMod: true,
+ SourceInfo: sourceInfo,
},
},
},
@@ -179,11 +180,11 @@
{
path: "net/http",
modulePath: stdlib.ModulePath,
- wantErr: derrors.NotFound,
+ wantErr: derrors.InvalidArgument,
},
} {
t.Run(test.path, func(t *testing.T) {
- got, err := ds.GetUnitMeta(ctx, test.path, test.modulePath, fetch.LocalVersion)
+ got, err := datasource.GetUnitMeta(ctx, test.path, test.modulePath, fetch.LocalVersion)
if test.wantErr != nil {
if !errors.Is(err, test.wantErr) {
t.Errorf("GetUnitMeta(%q, %q): %v; wantErr = %v)", test.path, test.modulePath, err, test.wantErr)
@@ -192,7 +193,7 @@
if err != nil {
t.Fatal(err)
}
- if diff := cmp.Diff(test.want, got); diff != "" {
+ if diff := cmp.Diff(test.want, got, cmp.AllowUnexported(source.Info{})); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
@@ -205,10 +206,7 @@
// This is a simple test to verify that data is fetched correctly. The
// return value of FetchResult is tested in internal/fetch so no need
// to repeat it.
- ctx, cancel, ds, err := setup(t)
- if err != nil {
- t.Fatalf("setup failed: %s", err.Error())
- }
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
for _, test := range []struct {
@@ -240,7 +238,7 @@
Path: test.path,
ModuleInfo: internal.ModuleInfo{ModulePath: test.modulePath},
}
- got, err := ds.GetUnit(ctx, um, 0, internal.BuildContext{})
+ got, err := datasource.GetUnit(ctx, um, 0, internal.BuildContext{})
if !test.wantLoaded {
if err == nil {
t.Fatalf("returned not loaded module %q", test.path)