internal/frontend: add getImportedByCount

getImportedByCount is added, which returns the string to be displayed
for a given datasource and imported by count.

Change-Id: If0f64999cb86de1045bbfd8fc179c9571da8f7dc
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/264321
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/frontend/imports.go b/internal/frontend/imports.go
index 84832bc..50a60cd 100644
--- a/internal/frontend/imports.go
+++ b/internal/frontend/imports.go
@@ -75,7 +75,9 @@
 	TotalIsExact bool // if false, then there may be more than Total
 }
 
-const importedByLimit = 20001
+// importedByLimit is the maximum number of importers displayed on the imported
+// by page.
+var importedByLimit = 20001
 
 // etchImportedByDetails fetches importers for the package version specified by
 // path and version from the database and returns a ImportedByDetails.
diff --git a/internal/frontend/unit_main.go b/internal/frontend/unit_main.go
index 82f0e23..0258e02 100644
--- a/internal/frontend/unit_main.go
+++ b/internal/frontend/unit_main.go
@@ -13,6 +13,7 @@
 
 	"github.com/google/safehtml"
 	"golang.org/x/pkgsite/internal"
+	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/experiment"
 	"golang.org/x/pkgsite/internal/godoc"
 	"golang.org/x/pkgsite/internal/log"
@@ -94,26 +95,6 @@
 		return nil, err
 	}
 
-	importedByCount := strconv.Itoa(unit.NumImportedBy)
-	if !experiment.IsActive(ctx, internal.ExperimentGetUnitWithOneQuery) {
-		// importedByCount is not supported when using a datasource proxy.
-		importedByCount = "0"
-		db, ok := ds.(*postgres.DB)
-		if ok {
-			importedBy, err := db.GetImportedBy(ctx, um.Path, um.ModulePath, importedByLimit)
-			if err != nil {
-				return nil, err
-			}
-			// If we reached the query limit, then we don't know the total
-			// and we'll indicate that with a '+'. For example, if the limit
-			// is 101 and we get 101 results, then we'll show '100+ Imported by'.
-			importedByCount = strconv.Itoa(len(importedBy))
-			if len(importedBy) == importedByLimit {
-				importedByCount = strconv.Itoa(len(importedBy)-1) + "+"
-			}
-		}
-	}
-
 	nestedModules, err := getNestedModules(ctx, ds, um)
 	if err != nil {
 		return nil, err
@@ -126,6 +107,10 @@
 	if err != nil {
 		return nil, err
 	}
+	importedByCount, err := getImportedByCount(ctx, ds, unit)
+	if err != nil {
+		return nil, err
+	}
 
 	var (
 		docBody, docOutline, mobileOutline safehtml.HTML
@@ -269,3 +254,35 @@
 	}
 	return u.Documentation.HTML
 }
+
+// getImportedByCount fetches the imported by count for the unit and returns a
+// string to be displayed. If the datasource does not support imported by, it
+// will return N/A.
+func getImportedByCount(ctx context.Context, ds internal.DataSource, unit *internal.Unit) (_ string, err error) {
+	defer derrors.Wrap(&err, "getImportedByCount(%q, %q, %q)", unit.Path, unit.ModulePath, unit.Version)
+	defer middleware.ElapsedStat(ctx, "getImportedByCount")()
+
+	db, ok := ds.(*postgres.DB)
+	if !ok {
+		return "N/A", nil
+	}
+
+	var count int
+	if experiment.IsActive(ctx, internal.ExperimentGetUnitWithOneQuery) {
+		count = unit.NumImportedBy
+	} else {
+		importedBy, err := db.GetImportedBy(ctx, unit.Path, unit.ModulePath, importedByLimit)
+		if err != nil {
+			return "", err
+		}
+		count = len(importedBy)
+	}
+	// If we reached the query limit, then we might know the total, but we
+	// won't display past importedByLimit results. Indicate that with a '+'.
+	// For example, if the limit is 101 and we get 101 results, then we'll show
+	// '100+ Imported by'.
+	if count >= importedByLimit {
+		return strconv.Itoa(count-1) + "+", nil
+	}
+	return strconv.Itoa(count), nil
+}
diff --git a/internal/frontend/unit_main_test.go b/internal/frontend/unit_main_test.go
index 72a744a..d7a847f 100644
--- a/internal/frontend/unit_main_test.go
+++ b/internal/frontend/unit_main_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"context"
+	"path"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -86,3 +87,70 @@
 		})
 	}
 }
+
+func TestGetImportedByCount(t *testing.T) {
+	defer postgres.ResetTestDB(testDB, t)
+
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+	defer cancel()
+
+	newModule := func(modPath string, pkgs ...*internal.Unit) *internal.Module {
+		m := sample.LegacyModule(modPath, sample.VersionString)
+		for _, p := range pkgs {
+			sample.AddUnit(m, p)
+		}
+		return m
+	}
+
+	pkg1 := sample.UnitForPackage("path.to/foo", "bar")
+	pkg2 := sample.UnitForPackage("path2.to/foo", "bar2")
+	pkg2.Imports = []string{pkg1.Path}
+
+	pkg3 := sample.UnitForPackage("path3.to/foo", "bar3")
+	pkg3.Imports = []string{pkg2.Path, pkg1.Path}
+
+	testModules := []*internal.Module{
+		newModule("path.to/foo", pkg1),
+		newModule("path2.to/foo", pkg2),
+		newModule("path3.to/foo", pkg3),
+	}
+
+	for _, m := range testModules {
+		if err := testDB.InsertModule(ctx, m); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	importedByLimit = 2
+	for _, tc := range []struct {
+		pkg  *internal.Unit
+		want string
+	}{
+		{
+			pkg:  pkg3,
+			want: "0",
+		},
+		{
+			pkg:  pkg2,
+			want: "1",
+		},
+		{
+			pkg:  pkg1,
+			want: "1+",
+		},
+	} {
+		t.Run(tc.pkg.Path, func(t *testing.T) {
+			otherVersion := newModule(path.Dir(tc.pkg.Path), tc.pkg)
+			otherVersion.Version = "v1.0.5"
+			pkg := otherVersion.Units[1]
+			got, err := getImportedByCount(ctx, testDB, pkg)
+			if err != nil {
+				t.Fatalf("getImportedByCount(ctx, db, %q) = %v err = %v, want %v",
+					tc.pkg.Path, got, err, tc.want)
+			}
+			if diff := cmp.Diff(tc.want, got); diff != "" {
+				t.Errorf("getImportedByCount(ctx, db, %q) mismatch (-want +got):\n%s", tc.pkg.Path, diff)
+			}
+		})
+	}
+}