internal/localdatasource: use getters

Instead of loading a list of modules initially, a local datasource now
takes a list of ModuleGetters, which are called on demand when a
module is requested.

For golang/go#47780

Change-Id: Ica3cc6d47de01ec78c451c0ef54b9ba0e0c5a96e
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/343591
Trust: Jonathan Amsterdam <jba@google.com>
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Julie Qiu <julie@golang.org>
diff --git a/cmd/pkgsite/main.go b/cmd/pkgsite/main.go
index c2147c6..ed03191 100644
--- a/cmd/pkgsite/main.go
+++ b/cmd/pkgsite/main.go
@@ -30,10 +30,12 @@
 	"github.com/google/safehtml/template"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/dcensus"
+	"golang.org/x/pkgsite/internal/fetch"
 	"golang.org/x/pkgsite/internal/frontend"
 	"golang.org/x/pkgsite/internal/localdatasource"
 	"golang.org/x/pkgsite/internal/log"
 	"golang.org/x/pkgsite/internal/middleware"
+	"golang.org/x/pkgsite/internal/source"
 )
 
 const defaultAddr = "localhost:8080" // default webserver address
@@ -53,7 +55,7 @@
 		paths = "."
 	}
 
-	lds := localdatasource.New()
+	lds := localdatasource.New(source.NewClient(time.Second))
 	dsg := func(context.Context) internal.DataSource { return lds }
 	server, err := frontend.NewServer(frontend.ServerConfig{
 		DataSourceGetter: dsg,
@@ -78,15 +80,20 @@
 	paths := strings.Split(pathList, ",")
 	loaded := len(paths)
 	for _, path := range paths {
-		var err error
+		var (
+			mg  fetch.ModuleGetter
+			err error
+		)
 		if *gopathMode {
-			err = ds.LoadFromGOPATH(ctx, path)
+			mg, err = localdatasource.NewGOPATHModuleGetter(path)
 		} else {
-			err = ds.Load(ctx, path)
+			mg, err = fetch.NewDirectoryModuleGetter("", path)
 		}
 		if err != nil {
 			log.Error(ctx, err)
 			loaded--
+		} else {
+			ds.AddModuleGetter(mg)
 		}
 	}
 
diff --git a/internal/localdatasource/datasource.go b/internal/localdatasource/datasource.go
index 69a0c6c..0188e34 100644
--- a/internal/localdatasource/datasource.go
+++ b/internal/localdatasource/datasource.go
@@ -9,13 +9,13 @@
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
 	"path"
 	"path/filepath"
 	"strings"
 	"sync"
-	"time"
 
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/derrors"
@@ -29,54 +29,78 @@
 	sourceClient *source.Client
 
 	mu            sync.Mutex
+	getters       []fetch.ModuleGetter
 	loadedModules map[string]*internal.Module
 }
 
 // New creates and returns a new local datasource that bypasses license
 // checks by default.
-func New() *DataSource {
+func New(sc *source.Client) *DataSource {
 	return &DataSource{
-		sourceClient:  source.NewClient(1 * time.Minute),
+		sourceClient:  sc,
 		loadedModules: make(map[string]*internal.Module),
 	}
 }
 
-// Load loads a module from the given local path. Loading is required before
-// being able to display the module.
-func (ds *DataSource) Load(ctx context.Context, localPath string) (err error) {
-	defer derrors.Wrap(&err, "Load(%q)", localPath)
-	return ds.fetch(ctx, "", localPath)
+// AddModuleGetter adds a module getter to the DataSource. To look up a module,
+// the getters are tried in the order they were added until the desired module
+// is found.
+func (ds *DataSource) AddModuleGetter(g fetch.ModuleGetter) {
+	ds.mu.Lock()
+	defer ds.mu.Unlock()
+	ds.getters = append(ds.getters, g)
 }
 
-// LoadFromGOPATH loads a module from GOPATH using the given import path. The full
-// path of the module should be GOPATH/src/importPath. If several GOPATHs exist, the
-// module is loaded from the first one that contains the import path. Loading is required
-// before being able to display the module.
-func (ds *DataSource) LoadFromGOPATH(ctx context.Context, importPath string) (err error) {
-	defer derrors.Wrap(&err, "LoadFromGOPATH(%q)", importPath)
-
-	path := getFullPath(importPath)
-	if path == "" {
-		return fmt.Errorf("path %s doesn't exist: %w", importPath, derrors.NotFound)
+// getModule gets the module at the given path and version. It first checks the
+// cache, and if it isn't there it then tries to fetch it.
+func (ds *DataSource) getModule(ctx context.Context, path, version string) (*internal.Module, error) {
+	if m := ds.getFromCache(path, version); m != nil {
+		return m, nil
 	}
-
-	return ds.fetch(ctx, importPath, path)
+	m, err := ds.fetch(ctx, path, version)
+	if err != nil {
+		return nil, err
+	}
+	ds.mu.Lock()
+	defer ds.mu.Unlock()
+	ds.loadedModules[m.ModulePath+"@"+m.Version] = m
+	return m, nil
 }
 
-// fetch fetches a module using FetchLocalModule and adds it to the datasource.
-// If the fetching fails, an error is returned.
-func (ds *DataSource) fetch(ctx context.Context, modulePath, localPath string) error {
-	fr := fetch.FetchLocalModule(ctx, modulePath, localPath, ds.sourceClient)
-	if fr.Error != nil {
-		return fr.Error
+// getFromCache returns a module from the cache if it is present, and nil otherwise.
+func (ds *DataSource) getFromCache(path, version string) *internal.Module {
+	ds.mu.Lock()
+	defer ds.mu.Unlock()
+	// Look for an exact match first.
+	if m := ds.loadedModules[path+"@"+version]; m != nil {
+		return m
 	}
+	// Look for the module path with LocalVersion, as for a directory-based or GOPATH-mode module.
+	return ds.loadedModules[path+"@"+fetch.LocalVersion]
+}
 
-	fr.Module.IsRedistributable = true
-	for _, unit := range fr.Module.Units {
+// fetch fetches a module using the configured ModuleGetters.
+// It tries each getter in turn until it finds one that has the module.
+func (ds *DataSource) fetch(ctx context.Context, modulePath, version string) (*internal.Module, error) {
+	for _, g := range ds.getters {
+		fr := fetch.FetchModule(ctx, modulePath, version, g, ds.sourceClient)
+		if fr.Error == nil {
+			adjust(fr.Module)
+			return fr.Module, nil
+		}
+		if !errors.Is(fr.Error, derrors.NotFound) {
+			return nil, fr.Error
+		}
+	}
+	return nil, fmt.Errorf("%s@%s: %w", modulePath, version, derrors.NotFound)
+}
+
+func adjust(m *internal.Module) {
+	m.IsRedistributable = true
+	for _, unit := range m.Units {
 		unit.IsRedistributable = true
 	}
-
-	for _, unit := range fr.Module.Units {
+	for _, unit := range m.Units {
 		for _, d := range unit.Documentation {
 			unit.BuildContexts = append(unit.BuildContexts, internal.BuildContext{
 				GOOS:   d.GOOS,
@@ -84,11 +108,18 @@
 			})
 		}
 	}
+}
 
-	ds.mu.Lock()
-	defer ds.mu.Unlock()
-	ds.loadedModules[fr.ModulePath] = fr.Module
-	return nil
+// NewGOPATHModuleGetter returns a module getter that uses the GOPATH
+// environment variable to find the module with the given import path.
+func NewGOPATHModuleGetter(importPath string) (_ fetch.ModuleGetter, err error) {
+	defer derrors.Wrap(&err, "NewGOPATHModuleGetter(%q)", importPath)
+
+	dir := getFullPath(importPath)
+	if dir == "" {
+		return nil, fmt.Errorf("path %s doesn't exist: %w", importPath, derrors.NotFound)
+	}
+	return fetch.NewDirectoryModuleGetter(importPath, dir)
 }
 
 // getFullPath takes an import path, tests it relative to each GOPATH, and returns
@@ -111,51 +142,30 @@
 func (ds *DataSource) GetUnit(ctx context.Context, pathInfo *internal.UnitMeta, fields internal.FieldSet, bc internal.BuildContext) (_ *internal.Unit, err error) {
 	defer derrors.Wrap(&err, "GetUnit(%q, %q)", pathInfo.Path, pathInfo.ModulePath)
 
-	modulepath := pathInfo.ModulePath
-	path := pathInfo.Path
-
-	ds.mu.Lock()
-	defer ds.mu.Unlock()
-	if ds.loadedModules[modulepath] == nil {
-		return nil, fmt.Errorf("%s not loaded: %w", modulepath, derrors.NotFound)
+	module, err := ds.getModule(ctx, pathInfo.ModulePath, pathInfo.Version)
+	if err != nil {
+		return nil, err
 	}
-
-	module := ds.loadedModules[modulepath]
 	for _, unit := range module.Units {
-		if unit.Path == path {
+		if unit.Path == pathInfo.Path {
 			return unit, nil
 		}
 	}
 
-	return nil, fmt.Errorf("%s not found: %w", path, derrors.NotFound)
+	return nil, fmt.Errorf("import path %s not found in module %s: %w", pathInfo.Path, pathInfo.ModulePath, derrors.NotFound)
 }
 
 // GetUnitMeta returns information about a path.
 func (ds *DataSource) GetUnitMeta(ctx context.Context, path, requestedModulePath, requestedVersion string) (_ *internal.UnitMeta, err error) {
 	defer derrors.Wrap(&err, "GetUnitMeta(%q, %q, %q)", path, requestedModulePath, requestedVersion)
 
-	if requestedModulePath == internal.UnknownModulePath {
-		requestedModulePath, err = ds.findModule(path)
-		if err != nil {
-			return nil, err
-		}
+	module, err := ds.findModule(ctx, path, requestedModulePath, requestedVersion)
+	if err != nil {
+		return nil, err
 	}
-
-	ds.mu.Lock()
-	module := ds.loadedModules[requestedModulePath]
-	ds.mu.Unlock()
-	if module == nil {
-		return nil, fmt.Errorf("%s not loaded: %w", requestedModulePath, derrors.NotFound)
-	}
-
 	um := &internal.UnitMeta{
-		Path: path,
-		ModuleInfo: internal.ModuleInfo{
-			ModulePath:        requestedModulePath,
-			Version:           fetch.LocalVersion,
-			CommitTime:        fetch.LocalCommitTime,
-			IsRedistributable: module.IsRedistributable,
-		},
+		Path:       path,
+		ModuleInfo: module.ModuleInfo,
 	}
 
 	for _, u := range module.Units {
@@ -168,23 +178,26 @@
 	return um, nil
 }
 
-// findModule finds the longest module path in loadedModules containing the given
-// package path. It iteratively checks parent directories to find an import path.
-// Returns an error if no module is found.
-func (ds *DataSource) findModule(pkgPath string) (_ string, err error) {
+// findModule finds the module with longest module path containing the given
+// package path. It returns an error if no module is found.
+func (ds *DataSource) findModule(ctx context.Context, pkgPath, modulePath, version string) (_ *internal.Module, err error) {
 	defer derrors.Wrap(&err, "findModule(%q)", pkgPath)
 
-	pkgPath = strings.TrimLeft(pkgPath, "/")
-
-	ds.mu.Lock()
-	defer ds.mu.Unlock()
-	for modulePath := pkgPath; modulePath != "" && modulePath != "."; modulePath = path.Dir(modulePath) {
-		if ds.loadedModules[modulePath] != nil {
-			return modulePath, nil
-		}
+	if modulePath != internal.UnknownModulePath {
+		return ds.getModule(ctx, modulePath, version)
 	}
 
-	return "", fmt.Errorf("%s not loaded: %w", pkgPath, derrors.NotFound)
+	pkgPath = strings.TrimLeft(pkgPath, "/")
+	for modulePath := pkgPath; modulePath != "" && modulePath != "."; modulePath = path.Dir(modulePath) {
+		m, err := ds.getModule(ctx, modulePath, version)
+		if err == nil {
+			return m, nil
+		}
+		if !errors.Is(err, derrors.NotFound) {
+			return nil, err
+		}
+	}
+	return nil, fmt.Errorf("could not find module for import path %s: %w", pkgPath, derrors.NotFound)
 }
 
 // GetLatestInfo is not implemented.
diff --git a/internal/localdatasource/datasource_test.go b/internal/localdatasource/datasource_test.go
index b83bb48..0a96178 100644
--- a/internal/localdatasource/datasource_test.go
+++ b/internal/localdatasource/datasource_test.go
@@ -7,6 +7,7 @@
 import (
 	"context"
 	"errors"
+	"log"
 	"os"
 	"testing"
 	"time"
@@ -18,23 +19,18 @@
 	"golang.org/x/pkgsite/internal/fetch"
 	"golang.org/x/pkgsite/internal/godoc/dochtml"
 	"golang.org/x/pkgsite/internal/licenses"
+	"golang.org/x/pkgsite/internal/source"
 	"golang.org/x/pkgsite/internal/stdlib"
 	"golang.org/x/pkgsite/internal/testing/testhelper"
 )
 
-var (
-	ctx        context.Context
-	cancel     func()
-	datasource *DataSource
-)
+func TestMain(m *testing.M) {
+	os.Exit(run(m))
+}
 
-func setup(t *testing.T) (context.Context, func(), *DataSource, error) {
-	t.Helper()
+var datasource *DataSource
 
-	// Setup only once.
-	if datasource != nil {
-		return ctx, cancel, datasource, nil
-	}
+func run(m *testing.M) int {
 	licenses.OmitExceptions = true
 	modules := []map[string]string{
 		{
@@ -82,31 +78,28 @@
 	}
 
 	dochtml.LoadTemplates(template.TrustedSourceFromConstant("../../static/doc"))
-	datasource = New()
-	ctx, cancel = context.WithTimeout(context.Background(), 20*time.Second)
+	datasource = New(source.NewClientForTesting())
 	for _, module := range modules {
 		directory, err := testhelper.CreateTestDirectory(module)
 		if err != nil {
-			return ctx, func() { cancel() }, nil, err
+			log.Fatal(err)
 		}
 		defer os.RemoveAll(directory)
 
-		err = datasource.Load(ctx, directory)
+		mg, err := fetch.NewDirectoryModuleGetter("", directory)
 		if err != nil {
-			return ctx, func() { cancel() }, nil, err
+			log.Fatal(err)
 		}
+		datasource.AddModuleGetter(mg)
 	}
-
-	return ctx, func() { cancel() }, datasource, nil
+	return m.Run()
 }
 
 func TestGetUnitMeta(t *testing.T) {
-	ctx, cancel, ds, err := setup(t)
-	if err != nil {
-		t.Fatalf("setup failed: %s", err.Error())
-	}
+	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
 	defer cancel()
 
+	sourceInfo := source.NewGitHubInfo("https://github.com/my/module", "", "v0.0.0")
 	for _, test := range []struct {
 		path, modulePath string
 		want             *internal.UnitMeta
@@ -122,6 +115,8 @@
 					Version:           fetch.LocalVersion,
 					CommitTime:        fetch.LocalCommitTime,
 					IsRedistributable: true,
+					HasGoMod:          true,
+					SourceInfo:        sourceInfo,
 				},
 				IsRedistributable: true,
 			},
@@ -137,6 +132,8 @@
 					Version:           fetch.LocalVersion,
 					CommitTime:        fetch.LocalCommitTime,
 					IsRedistributable: true,
+					HasGoMod:          true,
+					SourceInfo:        sourceInfo,
 				},
 				IsRedistributable: true,
 			},
@@ -152,6 +149,8 @@
 					IsRedistributable: true,
 					Version:           fetch.LocalVersion,
 					CommitTime:        fetch.LocalCommitTime,
+					HasGoMod:          true,
+					SourceInfo:        sourceInfo,
 				},
 				IsRedistributable: true,
 			},
@@ -168,6 +167,8 @@
 					Version:           fetch.LocalVersion,
 					CommitTime:        fetch.LocalCommitTime,
 					IsRedistributable: true,
+					HasGoMod:          true,
+					SourceInfo:        sourceInfo,
 				},
 			},
 		},
@@ -179,11 +180,11 @@
 		{
 			path:       "net/http",
 			modulePath: stdlib.ModulePath,
-			wantErr:    derrors.NotFound,
+			wantErr:    derrors.InvalidArgument,
 		},
 	} {
 		t.Run(test.path, func(t *testing.T) {
-			got, err := ds.GetUnitMeta(ctx, test.path, test.modulePath, fetch.LocalVersion)
+			got, err := datasource.GetUnitMeta(ctx, test.path, test.modulePath, fetch.LocalVersion)
 			if test.wantErr != nil {
 				if !errors.Is(err, test.wantErr) {
 					t.Errorf("GetUnitMeta(%q, %q): %v; wantErr = %v)", test.path, test.modulePath, err, test.wantErr)
@@ -192,7 +193,7 @@
 				if err != nil {
 					t.Fatal(err)
 				}
-				if diff := cmp.Diff(test.want, got); diff != "" {
+				if diff := cmp.Diff(test.want, got, cmp.AllowUnexported(source.Info{})); diff != "" {
 					t.Errorf("mismatch (-want +got):\n%s", diff)
 
 				}
@@ -205,10 +206,7 @@
 	// This is a simple test to verify that data is fetched correctly. The
 	// return value of FetchResult is tested in internal/fetch so no need
 	// to repeat it.
-	ctx, cancel, ds, err := setup(t)
-	if err != nil {
-		t.Fatalf("setup failed: %s", err.Error())
-	}
+	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
 	defer cancel()
 
 	for _, test := range []struct {
@@ -240,7 +238,7 @@
 				Path:       test.path,
 				ModuleInfo: internal.ModuleInfo{ModulePath: test.modulePath},
 			}
-			got, err := ds.GetUnit(ctx, um, 0, internal.BuildContext{})
+			got, err := datasource.GetUnit(ctx, um, 0, internal.BuildContext{})
 			if !test.wantLoaded {
 				if err == nil {
 					t.Fatalf("returned not loaded module %q", test.path)