internal/testing/fakedatasource: add a fake for internal.DataSource

This CL creates a new FakeDataSource type that has enough of an
implementation to replace some of the uses of the postgres database
(though not yet all) in the internal/frontend tests. Parts of the
implementation have been taken from the pre-existing datasource
implementation in FetchDataSource and postgres.DB.

For golang/go#61399

Change-Id: I761178d8d4f65457738c46c4733f1231ee32c645
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/520615
Run-TryBot: Michael Matloob <matloob@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
kokoro-CI: kokoro <noreply+kokoro@google.com>
diff --git a/internal/frontend/directory_test.go b/internal/frontend/directory_test.go
index a144e02..1f0c0f9 100644
--- a/internal/frontend/directory_test.go
+++ b/internal/frontend/directory_test.go
@@ -10,14 +10,14 @@
 
 	"github.com/google/go-cmp/cmp"
 	"golang.org/x/pkgsite/internal"
-	"golang.org/x/pkgsite/internal/postgres"
+	"golang.org/x/pkgsite/internal/testing/fakedatasource"
 	"golang.org/x/pkgsite/internal/testing/sample"
 )
 
 func TestGetNestedModules(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
-	defer postgres.ResetTestDB(testDB, t)
+	fds := fakedatasource.New()
 
 	for _, m := range []*internal.Module{
 		sample.Module("cloud.google.com/go", "v0.46.2", "storage", "spanner", "pubsub"),
@@ -30,7 +30,7 @@
 		sample.Module("cloud.google.com/go/storage/v9/module", "v9.0.0", sample.Suffix),
 		sample.Module("cloud.google.com/go/v2", "v2.0.0", "storage", "spanner", "pubsub"),
 	} {
-		postgres.MustInsertModule(ctx, t, testDB, m)
+		fds.MustInsertModule(m)
 	}
 
 	for _, test := range []struct {
@@ -103,7 +103,7 @@
 		},
 	} {
 		t.Run(test.modulePath, func(t *testing.T) {
-			got, err := getNestedModules(ctx, testDB, &internal.UnitMeta{
+			got, err := getNestedModules(ctx, fds, &internal.UnitMeta{
 				Path:       test.modulePath,
 				ModuleInfo: internal.ModuleInfo{ModulePath: test.modulePath},
 			}, test.subdirectories)
diff --git a/internal/frontend/imports_test.go b/internal/frontend/imports_test.go
index 25873b2..1aa5eea 100644
--- a/internal/frontend/imports_test.go
+++ b/internal/frontend/imports_test.go
@@ -12,6 +12,7 @@
 	"github.com/google/go-cmp/cmp"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/postgres"
+	"golang.org/x/pkgsite/internal/testing/fakedatasource"
 	"golang.org/x/pkgsite/internal/testing/sample"
 )
 
@@ -44,7 +45,7 @@
 		},
 	} {
 		t.Run(test.name, func(t *testing.T) {
-			defer postgres.ResetTestDB(testDB, t)
+			fds := fakedatasource.New()
 
 			ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 			defer cancel()
@@ -54,9 +55,9 @@
 			pkg := module.Units[1]
 			pkg.Imports = test.imports
 
-			postgres.MustInsertModule(ctx, t, testDB, module)
+			fds.MustInsertModule(module)
 
-			got, err := fetchImportsDetails(ctx, testDB, pkg.Path, pkg.ModulePath, pkg.Version)
+			got, err := fetchImportsDetails(ctx, fds, pkg.Path, pkg.ModulePath, pkg.Version)
 			if err != nil {
 				t.Fatalf("fetchImportsDetails(ctx, db, %q, %q) = %v err = %v, want %v",
 					module.Units[1].Path, module.Version, got, err, test.wantDetails)
diff --git a/internal/frontend/license_test.go b/internal/frontend/license_test.go
index af09140..539b0cd 100644
--- a/internal/frontend/license_test.go
+++ b/internal/frontend/license_test.go
@@ -15,8 +15,8 @@
 	"github.com/google/safehtml"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/licenses"
-	"golang.org/x/pkgsite/internal/postgres"
 	"golang.org/x/pkgsite/internal/stdlib"
+	"golang.org/x/pkgsite/internal/testing/fakedatasource"
 	"golang.org/x/pkgsite/internal/testing/sample"
 	"golang.org/x/pkgsite/internal/testing/testhelper"
 )
@@ -78,11 +78,11 @@
 	// github.com/valid/module_name/A/B
 	testModule.Units[2].Licenses = []*licenses.Metadata{mit, bsd}
 
-	defer postgres.ResetTestDB(testDB, t)
+	fds := fakedatasource.New()
 	ctx := context.Background()
-	postgres.MustInsertModule(ctx, t, testDB, testModule)
-	postgres.MustInsertModule(ctx, t, testDB, stdlibModule)
-	postgres.MustInsertModule(ctx, t, testDB, crlfModule)
+	fds.MustInsertModule(testModule)
+	fds.MustInsertModule(stdlibModule)
+	fds.MustInsertModule(crlfModule)
 	for _, test := range []struct {
 		err                                 error
 		name, fullPath, modulePath, version string
@@ -141,7 +141,7 @@
 		t.Run(test.name, func(t *testing.T) {
 			wantDetails := &LicensesDetails{Licenses: transformLicenses(
 				test.modulePath, test.version, test.want)}
-			got, err := fetchLicensesDetails(ctx, testDB, &internal.UnitMeta{
+			got, err := fetchLicensesDetails(ctx, fds, &internal.UnitMeta{
 				Path: test.fullPath,
 				ModuleInfo: internal.ModuleInfo{
 					ModulePath: test.modulePath,
diff --git a/internal/frontend/main_test.go b/internal/frontend/main_test.go
index 092c09e..69fb596 100644
--- a/internal/frontend/main_test.go
+++ b/internal/frontend/main_test.go
@@ -5,20 +5,16 @@
 package frontend
 
 import (
-	"context"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
 	"golang.org/x/pkgsite/internal"
-	"golang.org/x/pkgsite/internal/postgres"
+	"golang.org/x/pkgsite/internal/testing/fakedatasource"
 	"golang.org/x/pkgsite/internal/testing/sample"
 )
 
 func TestGetImportedByCount(t *testing.T) {
-	defer postgres.ResetTestDB(testDB, t)
-
-	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
-	defer cancel()
+	fds := fakedatasource.New()
 
 	newModule := func(modPath string, imports []string, numImportedBy int) *internal.Module {
 		m := sample.Module(modPath, sample.VersionString, "")
@@ -34,7 +30,7 @@
 	mod2 := newModule(p2, []string{p1}, 1)
 	mod3 := newModule(p3, []string{p1, p2}, 0)
 	for _, m := range []*internal.Module{mod1, mod2, mod3} {
-		postgres.MustInsertModule(ctx, t, testDB, m)
+		fds.MustInsertModule(m)
 	}
 
 	for _, test := range []struct {
diff --git a/internal/frontend/search_test.go b/internal/frontend/search_test.go
index 550e818..fa147b4 100644
--- a/internal/frontend/search_test.go
+++ b/internal/frontend/search_test.go
@@ -26,6 +26,7 @@
 	"golang.org/x/pkgsite/internal/licenses"
 	"golang.org/x/pkgsite/internal/osv"
 	"golang.org/x/pkgsite/internal/postgres"
+	"golang.org/x/pkgsite/internal/testing/fakedatasource"
 	"golang.org/x/pkgsite/internal/testing/sample"
 	"golang.org/x/pkgsite/internal/vuln"
 	"golang.org/x/text/language"
@@ -33,15 +34,14 @@
 )
 
 func TestDetermineSearchAction(t *testing.T) {
-	ctx := context.Background()
-	defer postgres.ResetTestDB(testDB, t)
 	golangTools := sample.Module("golang.org/x/tools", sample.VersionString, "internal/lsp")
 	std := sample.Module("std", sample.VersionString,
 		"cmd/go", "cmd/go/internal/auth", "fmt")
 	modules := []*internal.Module{golangTools, std}
 
+	fds := fakedatasource.New()
 	for _, v := range modules {
-		postgres.MustInsertModule(ctx, t, testDB, v)
+		fds.MustInsertModule(v)
 	}
 	vc, err := vuln.NewInMemoryClient(testEntries)
 	if err != nil {
@@ -144,7 +144,7 @@
 	} {
 		t.Run(test.name, func(t *testing.T) {
 			req := buildSearchRequest(t, test.method, test.query)
-			var ds internal.DataSource = testDB
+			var ds internal.DataSource = fds
 			if test.ds != nil {
 				ds = test.ds
 			}
@@ -234,7 +234,7 @@
 func TestFetchSearchPage(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
-	defer postgres.ResetTestDB(testDB, t)
+	fds := fakedatasource.New()
 
 	var (
 		now       = sample.NowTruncated()
@@ -327,7 +327,7 @@
 	}
 
 	for _, m := range []*internal.Module{moduleFoo, moduleBar} {
-		postgres.MustInsertModule(ctx, t, testDB, m)
+		fds.MustInsertModule(m)
 	}
 
 	for _, test := range []struct {
@@ -398,7 +398,7 @@
 		},
 	} {
 		t.Run(test.name, func(t *testing.T) {
-			got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, vc)
+			got, err := fetchSearchPage(ctx, fds, test.query, "", paginationParams{limit: 20, page: 1}, false, vc)
 			if err != nil {
 				t.Fatalf("fetchSearchPage(db, %q): %v", test.query, err)
 			}
@@ -420,13 +420,13 @@
 	for _, test := range []struct {
 		name string
 		tag  language.Tag
-		in   postgres.SearchResult
+		in   internal.SearchResult
 		want SearchResult
 	}{
 		{
 			name: "basic",
 			tag:  language.English,
-			in: postgres.SearchResult{
+			in: internal.SearchResult{
 				Name:          "pkg",
 				PackagePath:   "m.com/pkg",
 				ModulePath:    "m.com",
@@ -445,7 +445,7 @@
 		{
 			name: "command",
 			tag:  language.English,
-			in: postgres.SearchResult{
+			in: internal.SearchResult{
 				Name:          "main",
 				PackagePath:   "m.com/cmd",
 				ModulePath:    "m.com",
@@ -465,7 +465,7 @@
 		{
 			name: "stdlib",
 			tag:  language.English,
-			in: postgres.SearchResult{
+			in: internal.SearchResult{
 				Name:        "math",
 				PackagePath: "math",
 				ModulePath:  "std",
@@ -484,7 +484,7 @@
 		{
 			name: "German",
 			tag:  language.German,
-			in: postgres.SearchResult{
+			in: internal.SearchResult{
 				Name:          "pkg",
 				PackagePath:   "m.com/pkg",
 				ModulePath:    "m.com",
@@ -525,8 +525,9 @@
 		"cmd/go", "cmd/go/internal/auth", "fmt")
 	modules := []*internal.Module{golangTools, std}
 
+	fds := fakedatasource.New()
 	for _, v := range modules {
-		postgres.MustInsertModule(ctx, t, testDB, v)
+		fds.MustInsertModule(v)
 	}
 	for _, test := range []struct {
 		name  string
@@ -551,7 +552,7 @@
 		{"CVE alias", "CVE-2022-32190", "", searchModePackage},
 	} {
 		t.Run(test.name, func(t *testing.T) {
-			if got := searchRequestRedirectPath(ctx, testDB, test.query, test.mode, true); got != test.want {
+			if got := searchRequestRedirectPath(ctx, fds, test.query, test.mode, true); got != test.want {
 				t.Errorf("searchRequestRedirectPath(ctx, %q) = %q; want = %q", test.query, got, test.want)
 			}
 		})
diff --git a/internal/testing/fakedatasource/fakedatasource.go b/internal/testing/fakedatasource/fakedatasource.go
new file mode 100644
index 0000000..a06e6c9
--- /dev/null
+++ b/internal/testing/fakedatasource/fakedatasource.go
@@ -0,0 +1,307 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fakedatasource provides a fake implementation of the internal.DataSource interface.
+package fakedatasource
+
+import (
+	"context"
+	"fmt"
+	"sort"
+	"strings"
+
+	"golang.org/x/mod/module"
+	"golang.org/x/mod/semver"
+	"golang.org/x/pkgsite/internal"
+	"golang.org/x/pkgsite/internal/derrors"
+	"golang.org/x/pkgsite/internal/licenses"
+	"golang.org/x/pkgsite/internal/version"
+)
+
+// FakeDataSource provides a fake implementation of the internal.DataSource interface.
+type FakeDataSource struct {
+	modules map[module.Version]*internal.Module
+}
+
+// New returns an initialized FakeDataSource.
+func New() *FakeDataSource {
+	return &FakeDataSource{modules: make(map[module.Version]*internal.Module)}
+}
+
+// InsertModule adds the module to the FakeDataSource.
+func (ds *FakeDataSource) MustInsertModule(m *internal.Module) {
+	if m != nil {
+		for _, u := range m.Units {
+			ds.populateUnitSubdirectories(u, m)
+
+			// Make license info consistent.
+			if u.Licenses != nil {
+				// Sort licenses as postgres database does.
+				sort.Slice(u.Licenses, func(i, j int) bool {
+					return compareLicenses(u.Licenses[i], u.Licenses[j])
+				})
+				// Make sure LicenseContents match up with Licenses
+				u.LicenseContents = nil
+				for _, ul := range u.Licenses {
+					for _, ml := range m.Licenses {
+						if sameLicense(*ul, *ml.Metadata) {
+							u.LicenseContents = append(u.LicenseContents, ml)
+						}
+					}
+				}
+			}
+		}
+	}
+
+	ds.modules[module.Version{Path: m.ModulePath, Version: m.Version}] = m
+}
+
+// compareLicenses reports whether i < j according to our license sorting
+// semantics. This is what the postgres database uses to sort licenses.
+func compareLicenses(i, j *licenses.Metadata) bool {
+	if len(strings.Split(i.FilePath, "/")) > len(strings.Split(j.FilePath, "/")) {
+		return true
+	}
+	return i.FilePath < j.FilePath
+}
+
+func sameLicense(a, b licenses.Metadata) bool {
+	return a.FilePath == b.FilePath
+}
+
+func (ds *FakeDataSource) populateUnitSubdirectories(u *internal.Unit, m *internal.Module) {
+	p := u.Path + "/"
+	for _, u2 := range m.Units {
+		if strings.HasPrefix(u2.Path, p) || u.Path == "std" {
+			var syn string
+			if len(u2.Documentation) > 0 {
+				syn = u2.Documentation[0].Synopsis
+			}
+			u.Subdirectories = append(u.Subdirectories, &internal.PackageMeta{
+				Path:              u2.Path,
+				Name:              u2.Name,
+				Synopsis:          syn,
+				IsRedistributable: u2.IsRedistributable,
+				Licenses:          u2.Licenses,
+			})
+		}
+	}
+}
+
+// compareVersion returns -1 if a's version is less than b's, 0 if they're the same
+// and 1 if a's version is greater than b's.
+// It panics if they don't have the same module path with the major version
+// suffix removed.
+func compareVersion(a, b *internal.ModuleInfo) int {
+	aprefix, asuffix, _ := module.SplitPathVersion(a.ModulePath)
+	bprefix, bsuffix, _ := module.SplitPathVersion(b.ModulePath)
+	if aprefix != bprefix {
+		panic("compareVersion called for two modules with different paths")
+	}
+
+	if asuffix == bsuffix {
+		return semver.Compare(a.Version, b.Version)
+	}
+	return semver.Compare(module.PathMajorPrefix(asuffix), module.PathMajorPrefix(bsuffix))
+}
+
+// GetNestedModules returns the latest major version of all nested modules
+// given a modulePath path prefix.
+func (ds *FakeDataSource) GetNestedModules(ctx context.Context, modulePath string) ([]*internal.ModuleInfo, error) {
+	latest := map[string]*internal.ModuleInfo{}
+	for _, mod := range ds.modules {
+		if mod.ModulePath != modulePath && !strings.HasPrefix(mod.ModulePath, modulePath+"/") {
+			continue
+		}
+
+		prefix, _, _ := module.SplitPathVersion(mod.ModulePath)
+		curlatest, ok := latest[prefix]
+		if !ok {
+			latest[prefix] = &mod.ModuleInfo
+			continue
+		}
+		if compareVersion(&mod.ModuleInfo, curlatest) > 0 {
+			latest[prefix] = &mod.ModuleInfo
+		}
+	}
+	var infos []*internal.ModuleInfo
+	for _, info := range latest {
+		infos = append(infos, info)
+	}
+	sort.Slice(infos, func(i, j int) bool {
+		prefixi, _, _ := module.SplitPathVersion(infos[i].ModulePath)
+		prefixj, _, _ := module.SplitPathVersion(infos[j].ModulePath)
+		return prefixi < prefixj
+	})
+	return infos, nil
+}
+
+// GetUnit returns information about a directory, which may also be a
+// module and/or package. The module and version must both be known.
+// The BuildContext selects the documentation to read.
+func (ds *FakeDataSource) GetUnit(ctx context.Context, um *internal.UnitMeta, fields internal.FieldSet, bc internal.BuildContext) (*internal.Unit, error) {
+	m := ds.getModule(um.ModulePath, um.Version)
+	if m == nil {
+		return nil, derrors.NotFound
+	}
+	u := findUnit(m, um.Path)
+	if u == nil {
+		return nil, fmt.Errorf("import path %s not found in module %s: %w", um.Path, um.ModulePath, derrors.NotFound)
+	}
+	// Return only the Documentation matching the given BuildContext, if any.
+	// Since we cache the module and its units, we have to copy this unit before we modify it.
+	// It can be a shallow copy, since we're only modifying the Unit.Documentation field.
+	u2 := *u
+	if d := matchingDoc(u.Documentation, bc); d != nil {
+		u2.Documentation = []*internal.Documentation{d}
+	} else {
+		u2.Documentation = nil
+	}
+	return &u2, nil
+}
+
+// matchingDoc returns the Documentation that matches the given build context
+// and comes earliest in build-context order. It returns nil if there is none.
+func matchingDoc(docs []*internal.Documentation, bc internal.BuildContext) *internal.Documentation {
+	var (
+		dMin  *internal.Documentation
+		bcMin *internal.BuildContext // sorts last
+	)
+	for _, d := range docs {
+		dbc := d.BuildContext()
+		if bc.Match(dbc) && (bcMin == nil || internal.CompareBuildContexts(dbc, *bcMin) < 0) {
+			dMin = d
+			bcMin = &dbc
+		}
+	}
+	return dMin
+}
+
+// GetUnitMeta returns information about a path.
+func (ds *FakeDataSource) GetUnitMeta(ctx context.Context, path, requestedModulePath, requestedVersion string) (_ *internal.UnitMeta, err error) {
+	module := ds.findModule(path, requestedModulePath, requestedVersion)
+	if module == nil {
+		return nil, fmt.Errorf("could not find module for import path %s: %w", path, derrors.NotFound)
+	}
+	um := &internal.UnitMeta{
+		Path:       path,
+		ModuleInfo: module.ModuleInfo,
+	}
+	u := findUnit(module, path)
+	if u == nil {
+		return nil, derrors.NotFound
+	}
+	um.Name = u.Name
+	um.IsRedistributable = u.IsRedistributable
+	return um, nil
+}
+
+// findModule finds the module with longest module path containing the given
+// package path. It returns an error if no module is found.
+func (ds *FakeDataSource) findModule(pkgPath, modulePath, version string) *internal.Module {
+	if modulePath != internal.UnknownModulePath {
+		return ds.getModule(modulePath, version)
+	}
+	pkgPath = strings.TrimLeft(pkgPath, "/")
+	for _, modulePath := range internal.CandidateModulePaths(pkgPath) {
+		if m := ds.getModule(modulePath, version); m != nil {
+			return m
+		}
+
+	}
+	return nil
+}
+
+func (ds *FakeDataSource) getModule(modulePath, vers string) *internal.Module {
+	if vers == version.Latest {
+		return ds.getLatestModule(modulePath)
+	}
+
+	return ds.modules[module.Version{Path: modulePath, Version: vers}]
+}
+
+func (ds *FakeDataSource) getLatestModule(modulePath string) *internal.Module {
+	var latestVersion module.Version
+	var latestModule *internal.Module
+	for vers, mod := range ds.modules {
+		if vers.Path == modulePath &&
+			(latestVersion == (module.Version{}) ||
+				version.Later(vers.Version, latestVersion.Version)) {
+			latestVersion = vers
+			latestModule = mod
+			continue
+		}
+	}
+	if latestModule == nil {
+		return nil
+	}
+	return latestModule
+}
+
+// findUnit returns the unit with the given path in m, or nil if none.
+func findUnit(m *internal.Module, path string) *internal.Unit {
+	for _, u := range m.Units {
+		if u.Path == path {
+			return u
+		}
+	}
+	return nil
+}
+
+// GetModuleReadme is not implemented.
+func (ds *FakeDataSource) GetModuleReadme(ctx context.Context, modulePath, resolvedVersion string) (*internal.Readme, error) {
+	return nil, nil
+}
+
+// GetLatestInfo gets information about the latest versions of a unit and module.
+// See LatestInfo for documentation.
+func (ds *FakeDataSource) GetLatestInfo(ctx context.Context, unitPath, modulePath string, latestUnitMeta *internal.UnitMeta) (latest internal.LatestInfo, err error) {
+	return internal.LatestInfo{}, nil
+}
+
+// SearchSupport reports the search types supported by this datasource.
+func (ds *FakeDataSource) SearchSupport() internal.SearchSupport {
+	// internal/frontend.TestDetermineSearchAction depends on us returning FullSearch
+	// even though it doesn't depend on the search results.
+	return internal.FullSearch
+}
+
+// Search searches for packages matching the given query.
+// It's a basic search of documentation synopses only enough to satisfy unit tests.
+func (ds *FakeDataSource) Search(ctx context.Context, q string, opts internal.SearchOptions) (results []*internal.SearchResult, err error) {
+	terms := strings.Fields(q)
+
+	for _, m := range ds.modules {
+		for _, u := range m.Units {
+			var containsAllTerms bool
+			if len(terms) > 0 {
+				containsAllTerms = true
+			}
+			synopsis := ""
+			for _, d := range u.Documentation {
+				synopsis += d.Synopsis
+			}
+			for _, term := range terms {
+				containsAllTerms = containsAllTerms && strings.Contains(synopsis, term)
+			}
+			if containsAllTerms {
+				result := &internal.SearchResult{
+					Name:        u.Name,
+					PackagePath: u.Path,
+					ModulePath:  m.ModulePath,
+					Version:     m.Version,
+					Synopsis:    synopsis,
+					CommitTime:  m.CommitTime,
+					NumResults:  1,
+				}
+				for _, licence := range u.Licenses {
+					result.Licenses = append(result.Licenses, licence.Types...)
+				}
+				results = append(results, result)
+			}
+
+		}
+	}
+	return results, nil
+}