internal/frontend: refactor serveSearch for testing

Split serveSearch into two functions: one computes the action to take
(redirect, serve a page or error) and the other carries out the action.

This makes it much easier to test the search logic, since we don't
have to examine HTML output, just the searchAction struct.

Add a test that verifies much of the high-level search logic:
various errors, when to redirect, etc.

Change-Id: I56d31264fd511420c94961ed6de4e464ebafd27b
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/431176
Run-TryBot: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Jamal Carvalho <jamal@golang.org>
diff --git a/internal/frontend/search.go b/internal/frontend/search.go
index cfa22cd..1d3b8c1 100644
--- a/internal/frontend/search.go
+++ b/internal/frontend/search.go
@@ -35,22 +35,46 @@
 // /search?q=<query>. If <query> is an exact match for a package path, the user
 // will be redirected to the details page.
 func (s *Server) serveSearch(w http.ResponseWriter, r *http.Request, ds internal.DataSource) error {
+	action, err := determineSearchAction(r, ds, s.vulnClient)
+	if err != nil {
+		return err
+	}
+	if action.redirectURL != "" {
+		http.Redirect(w, r, action.redirectURL, http.StatusFound)
+		return nil
+	}
+	action.page.setBasePage(s.newBasePage(r, action.title))
+	if s.shouldServeJSON(r) {
+		return s.serveJSONPage(w, r, action.page)
+	}
+	s.servePage(r.Context(), w, action.template, action.page)
+	return nil
+}
+
+type searchAction struct {
+	redirectURL string
+	title       string
+	template    string
+	page        interface{ setBasePage(basePage) }
+}
+
+func determineSearchAction(r *http.Request, ds internal.DataSource, vulnClient vulnc.Client) (*searchAction, error) {
 	if r.Method != http.MethodGet && r.Method != http.MethodHead {
-		return &serverError{status: http.StatusMethodNotAllowed}
+		return nil, &serverError{status: http.StatusMethodNotAllowed}
 	}
 	db, ok := ds.(*postgres.DB)
 	if !ok {
 		// The proxydatasource does not support the imported by page.
-		return datasourceNotSupportedErr()
+		return nil, datasourceNotSupportedErr()
 	}
 
 	ctx := r.Context()
 	cq, filters := searchQueryAndFilters(r)
 	if !utf8.ValidString(cq) {
-		return &serverError{status: http.StatusBadRequest}
+		return nil, &serverError{status: http.StatusBadRequest}
 	}
 	if len(filters) > 1 {
-		return &serverError{
+		return nil, &serverError{
 			status: http.StatusBadRequest,
 			epage: &errorPage{
 				messageTemplate: template.MakeTrustedTemplate(
@@ -59,7 +83,7 @@
 		}
 	}
 	if len(cq) > maxSearchQueryLength {
-		return &serverError{
+		return nil, &serverError{
 			status: http.StatusBadRequest,
 			epage: &errorPage{
 				messageTemplate: template.MakeTrustedTemplate(
@@ -68,12 +92,11 @@
 		}
 	}
 	if cq == "" {
-		http.Redirect(w, r, "/", http.StatusFound)
-		return nil
+		return &searchAction{redirectURL: "/"}, nil
 	}
 	pageParams := newPaginationParams(r, defaultSearchLimit)
 	if pageParams.offset() > maxSearchOffset {
-		return &serverError{
+		return nil, &serverError{
 			status: http.StatusBadRequest,
 			epage: &errorPage{
 				messageTemplate: template.MakeTrustedTemplate(
@@ -82,7 +105,7 @@
 		}
 	}
 	if pageParams.limit > maxSearchPageSize {
-		return &serverError{
+		return nil, &serverError{
 			status: http.StatusBadRequest,
 			epage: &errorPage{
 				messageTemplate: template.MakeTrustedTemplate(
@@ -92,41 +115,26 @@
 	}
 	mode := searchMode(r)
 	if path := searchRequestRedirectPath(ctx, ds, cq, mode); path != "" {
-		http.Redirect(w, r, path, http.StatusFound)
-		return nil
+		return &searchAction{redirectURL: path}, nil
 	}
-
-	vulnListPage, redirectURL, err := searchVulnAlias(ctx, mode, cq, s.vulnClient)
-	if err != nil {
-		return err
+	action, err := searchVulnAlias(ctx, mode, cq, vulnClient)
+	if action != nil || err != nil {
+		return action, err
 	}
-	if redirectURL != "" {
-		http.Redirect(w, r, redirectURL, http.StatusFound)
-		return nil
-	}
-	if vulnListPage != nil {
-		vulnListPage.basePage = s.newBasePage(r, fmt.Sprintf("%s - Vulnerability Reports", cq))
-		if s.shouldServeJSON(r) {
-			return s.serveJSONPage(w, r, vulnListPage)
-		}
-		s.servePage(ctx, w, "vuln/list", vulnListPage)
-		return nil
-	}
-
 	var symbol string
 	if len(filters) > 0 {
 		symbol = filters[0]
 	}
 	var getVulnEntries vulnEntriesFunc
-	if s.vulnClient != nil {
-		getVulnEntries = s.vulnClient.GetByModule
+	if vulnClient != nil {
+		getVulnEntries = vulnClient.GetByModule
 	}
 	page, err := fetchSearchPage(ctx, db, cq, symbol, pageParams, mode == searchModeSymbol, getVulnEntries)
 	if err != nil {
 		// Instead of returning a 500, return a 408, since symbol searches may
 		// timeout for very popular symbols.
 		if mode == searchModeSymbol && strings.Contains(err.Error(), "i/o timeout") {
-			return &serverError{
+			return nil, &serverError{
 				status: http.StatusRequestTimeout,
 				epage: &errorPage{
 					messageTemplate: template.MakeTrustedTemplate(
@@ -134,15 +142,14 @@
 				},
 			}
 		}
-		return fmt.Errorf("fetchSearchPage(ctx, db, %q): %v", cq, err)
+		return nil, fmt.Errorf("fetchSearchPage(ctx, db, %q): %v", cq, err)
 	}
-	page.basePage = s.newBasePage(r, fmt.Sprintf("%s - Search Results", cq))
 	page.SearchMode = mode
-	if s.shouldServeJSON(r) {
-		return s.serveJSONPage(w, r, page)
-	}
-	s.servePage(ctx, w, "search", page)
-	return nil
+	return &searchAction{
+		title:    fmt.Sprintf("%s - Search Results", cq),
+		template: "search",
+		page:     page,
+	}, nil
 }
 
 const (
@@ -359,27 +366,31 @@
 	return fmt.Sprintf("/%s", requestedPath)
 }
 
-func searchVulnAlias(ctx context.Context, mode, cq string, vulnClient vulnc.Client) (_ *VulnListPage, redirectURL string, err error) {
+func searchVulnAlias(ctx context.Context, mode, cq string, vulnClient vulnc.Client) (_ *searchAction, err error) {
 	defer derrors.Wrap(&err, "searchVulnAlias(%q, %q)", mode, cq)
 
 	if mode != searchModeVuln || !isVulnAlias(cq) {
-		return nil, "", nil
+		return nil, nil
 	}
 	aliasEntries, err := vulnClient.GetByAlias(ctx, cq)
 	if err != nil {
-		return nil, "", err
+		return nil, err
 	}
 	switch len(aliasEntries) {
 	case 0:
-		return nil, "", &serverError{status: http.StatusNotFound}
-	case 1: // redirect
-		return nil, "/vuln/" + aliasEntries[0].ID, nil
+		return nil, &serverError{status: http.StatusNotFound}
+	case 1:
+		return &searchAction{redirectURL: "/vuln/" + aliasEntries[0].ID}, nil
 	default:
 		var entries []OSVEntry
 		for _, e := range aliasEntries {
 			entries = append(entries, OSVEntry{e})
 		}
-		return &VulnListPage{Entries: entries}, "", nil
+		return &searchAction{
+			title:    fmt.Sprintf("%s - Vulnerability Reports", cq),
+			template: "vuln/list",
+			page:     &VulnListPage{Entries: entries},
+		}, nil
 	}
 }
 
diff --git a/internal/frontend/search_test.go b/internal/frontend/search_test.go
index 5d6c4ac..d1e9bbf 100644
--- a/internal/frontend/search_test.go
+++ b/internal/frontend/search_test.go
@@ -7,13 +7,18 @@
 import (
 	"context"
 	"fmt"
+	"net/http"
 	"net/http/httptest"
+	"strings"
 	"testing"
 	"time"
 
+	"net/url"
+
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"golang.org/x/pkgsite/internal"
+	"golang.org/x/pkgsite/internal/fetchdatasource"
 	"golang.org/x/pkgsite/internal/licenses"
 	"golang.org/x/pkgsite/internal/postgres"
 	"golang.org/x/pkgsite/internal/testing/sample"
@@ -22,6 +27,150 @@
 	"golang.org/x/vuln/osv"
 )
 
+func TestDetermineSearchAction(t *testing.T) {
+	ctx := context.Background()
+	defer postgres.ResetTestDB(testDB, t)
+	golangTools := sample.Module("golang.org/x/tools", sample.VersionString, "internal/lsp")
+	std := sample.Module("std", sample.VersionString,
+		"cmd/go", "cmd/go/internal/auth", "fmt")
+	modules := []*internal.Module{golangTools, std}
+
+	for _, v := range modules {
+		postgres.MustInsertModule(ctx, t, testDB, v)
+	}
+	vc := newVulndbTestClient(testEntries)
+	for _, test := range []struct {
+		name         string
+		method       string
+		ds           internal.DataSource
+		query        string // query param part of URL
+		wantRedirect string
+		wantTemplate string
+		wantStatus   int // 0 means no error
+	}{
+		{
+			name:       "wrong method",
+			method:     "POST",
+			wantStatus: http.StatusMethodNotAllowed,
+		},
+		{
+			name:       "bad data source",
+			ds:         &fetchdatasource.FetchDataSource{},
+			wantStatus: http.StatusFailedDependency,
+		},
+		{
+			name:       "invalid query string",
+			query:      "q=\xF4\x90\x80\x80",
+			wantStatus: http.StatusBadRequest,
+		},
+		{
+			name:       "too many filters",
+			query:      "q=" + url.QueryEscape("a #b #c"),
+			wantStatus: http.StatusBadRequest,
+		},
+		{
+			name:       "query too long",
+			query:      "q=" + strings.Repeat("x", maxSearchQueryLength+1),
+			wantStatus: http.StatusBadRequest,
+		},
+		{
+			name:         "empty query",
+			wantRedirect: "/",
+		},
+		{
+			name:       "offset too large",
+			query:      "q=foo&page=100",
+			wantStatus: http.StatusBadRequest,
+		},
+		{
+			name:       "limit too large",
+			query:      "q=foo&limit=" + fmt.Sprint(maxSearchPageSize+1),
+			wantStatus: http.StatusBadRequest,
+		},
+		// Some redirections; see more at TestSearchRequestRedirectPath.
+		{
+			name:         "Go vuln report",
+			query:        "q=GO-2020-1234",
+			wantRedirect: "/vuln/GO-2020-1234?q", // ??? DO WE WANT THE "?q" ???
+		},
+		{
+			name:         "known unit",
+			query:        "q=golang.org/x/tools",
+			wantRedirect: "/golang.org/x/tools",
+		},
+		// Vuln aliases.
+		// See testEntries in vulns_test.go to understand results.
+		// See TestSearchVulnAlias in this file for more tests.
+		{
+			name:         "vuln alias single",
+			query:        "q=GHSA-aaaa-bbbb-cccc&m=vuln",
+			wantRedirect: "/vuln/GO-1990-01",
+		},
+		{
+			name:         "vuln alias multi",
+			query:        "q=CVE-2000-1&m=vuln",
+			wantTemplate: "vuln/list",
+		},
+		{
+			// We turn on vuln mode if the query matches a vuln alias.
+			name:         "vuln alias not vuln mode",
+			query:        "q=GHSA-aaaa-bbbb-cccc",
+			wantRedirect: "/vuln/GO-1990-01",
+		},
+		{
+			// An explicit mode overrides that.
+			name:         "vuln alias symbol mode",
+			query:        "q=GHSA-aaaa-bbbb-cccc?m=symbol",
+			wantTemplate: "search",
+		},
+		{
+			name:         "normal search",
+			query:        "q=foo",
+			wantTemplate: "search",
+		},
+	} {
+		t.Run(test.name, func(t *testing.T) {
+			req := buildSearchRequest(t, test.method, test.query)
+			var ds internal.DataSource = testDB
+			if test.ds != nil {
+				ds = test.ds
+			}
+			gotAction, err := determineSearchAction(req, ds, vc)
+			if err != nil {
+				serr, ok := err.(*serverError)
+				if !ok {
+					t.Fatal(err)
+				}
+				if g, w := serr.status, test.wantStatus; g != w {
+					t.Errorf("got status %d, want %d", g, w)
+				}
+				return
+			}
+			if g, w := gotAction.redirectURL, test.wantRedirect; g != w {
+				t.Errorf("redirect:\ngot  %q\nwant %q", g, w)
+			}
+			if g, w := gotAction.template, test.wantTemplate; g != w {
+				t.Errorf("template:\ngot  %q\nwant %q", g, w)
+			}
+		})
+	}
+}
+
+func buildSearchRequest(t *testing.T, method, query string) *http.Request {
+	if method == "" {
+		method = "GET"
+	}
+	u := "/search"
+	if query != "" {
+		u += "?" + query
+	}
+	req, err := http.NewRequest(method, u, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return req
+}
+
 func TestSearchQueryAndMode(t *testing.T) {
 	for _, test := range []struct {
 		name, m, q, wantSearchMode string
@@ -447,15 +596,22 @@
 		},
 	} {
 		t.Run(test.name, func(t *testing.T) {
-			gotPage, gotURL, err := searchVulnAlias(context.Background(), test.mode, test.query, vc)
+			gotAction, err := searchVulnAlias(context.Background(), test.mode, test.query, vc)
 			if (err != nil) != test.wantErr {
 				t.Fatalf("got %v, want error %t", err, test.wantErr)
 			}
-			if !cmp.Equal(gotPage, test.wantPage, cmpopts.IgnoreUnexported(VulnListPage{})) {
-				t.Errorf("page:\ngot  %+v\nwant %+v", gotPage, test.wantPage)
+			var wantAction *searchAction
+			if test.wantURL != "" {
+				wantAction = &searchAction{redirectURL: test.wantURL}
+			} else if test.wantPage != nil {
+				wantAction = &searchAction{
+					title:    test.query + " - Vulnerability Reports",
+					template: "vuln/list",
+					page:     test.wantPage,
+				}
 			}
-			if gotURL != test.wantURL {
-				t.Errorf("redirect: got %q, want %q", gotURL, test.wantURL)
+			if !cmp.Equal(gotAction, wantAction, cmp.AllowUnexported(searchAction{}), cmpopts.IgnoreUnexported(VulnListPage{})) {
+				t.Errorf("\ngot  %+v\nwant %+v", gotAction, wantAction)
 			}
 		})
 	}
diff --git a/internal/frontend/server.go b/internal/frontend/server.go
index 908c048..2ebd73f 100644
--- a/internal/frontend/server.go
+++ b/internal/frontend/server.go
@@ -422,6 +422,8 @@
 	SearchModeSymbol string
 }
 
+func (p *basePage) setBasePage(bp basePage) { *p = bp }
+
 // licensePolicyPage is used to generate the static license policy page.
 type licensePolicyPage struct {
 	basePage