internal/postgres: add GetVersionMapsNon2xxStatus

GetVersionMapsWithNon2xxStatus is added, which returns the version_maps
correlating to a given path with a 4xx or 5xx status.

Change-Id: I99d52bbc38183939331e2d991a3107635036ea5b
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/281673
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/postgres/version_map.go b/internal/postgres/version_map.go
index 73a2348..8fdaa4c 100644
--- a/internal/postgres/version_map.go
+++ b/internal/postgres/version_map.go
@@ -9,6 +9,8 @@
 	"database/sql"
 	"fmt"
 
+	"github.com/Masterminds/squirrel"
+	"github.com/lib/pq"
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/derrors"
 	"golang.org/x/pkgsite/internal/version"
@@ -71,22 +73,15 @@
 		return nil, fmt.Errorf("modulePath must be specified: %w", derrors.InvalidArgument)
 	}
 
-	query := `
-		SELECT
-			module_path,
-			requested_version,
-			resolved_version,
-			go_mod_path,
-			status,
-			error,
-			updated_at
-		FROM
-			version_map
-		WHERE
-			module_path=$1
-			AND requested_version=$2;`
+	q, args, err := versionMapSelect().
+		Where(squirrel.Eq{"module_path": modulePath}).
+		Where(squirrel.Eq{"requested_version": requestedVersion}).
+		PlaceholderFormat(squirrel.Dollar).ToSql()
+	if err != nil {
+		return nil, err
+	}
 	var vm internal.VersionMap
-	err = db.db.QueryRow(ctx, query, modulePath, requestedVersion).Scan(
+	err = db.db.QueryRow(ctx, q, args...).Scan(
 		&vm.ModulePath, &vm.RequestedVersion, &vm.ResolvedVersion, &vm.GoModPath,
 		&vm.Status, &vm.Error, &vm.UpdatedAt)
 	switch err {
@@ -98,3 +93,50 @@
 		return nil, err
 	}
 }
+
+// GetVersionMapsNon2xxStatus returns all of the version maps for the provided
+// path and requested version if they are present.
+func (db *DB) GetVersionMapsNon2xxStatus(ctx context.Context, paths []string, requestedVersion string) (_ []*internal.VersionMap, err error) {
+	defer derrors.Wrap(&err, "DB.GetVersionMapsWith4xxStatus(ctx, %v, %q)", paths, requestedVersion)
+
+	var result []*internal.VersionMap
+	versionMaps := map[string]*internal.VersionMap{}
+	collect := func(rows *sql.Rows) error {
+		var vm internal.VersionMap
+		if err := rows.Scan(
+			&vm.ModulePath, &vm.RequestedVersion, &vm.ResolvedVersion, &vm.GoModPath,
+			&vm.Status, &vm.Error, &vm.UpdatedAt); err != nil {
+			return err
+		}
+		if _, ok := versionMaps[vm.ModulePath]; !ok {
+			versionMaps[vm.ModulePath] = &vm
+			result = append(result, &vm)
+		}
+		return nil
+	}
+	q, args, err := versionMapSelect().
+		Where("module_path = ANY(?)", pq.Array(paths)).
+		Where(squirrel.Or{squirrel.Eq{"requested_version": requestedVersion}, squirrel.Eq{"resolved_version": requestedVersion}}).
+		Where(squirrel.GtOrEq{"status": 400}).
+		OrderBy("module_path DESC").
+		PlaceholderFormat(squirrel.Dollar).ToSql()
+	if err != nil {
+		return nil, fmt.Errorf("squirrel.ToSql: %v", err)
+	}
+	if err := db.db.RunQuery(ctx, q, collect, args...); err != nil {
+		return nil, err
+	}
+	return result, nil
+}
+
+func versionMapSelect() squirrel.SelectBuilder {
+	return squirrel.Select(
+		"module_path",
+		"requested_version",
+		"resolved_version",
+		"go_mod_path",
+		"status",
+		"error",
+		"updated_at",
+	).From("version_map")
+}
diff --git a/internal/postgres/version_map_test.go b/internal/postgres/version_map_test.go
index 6869fcd..3e50637 100644
--- a/internal/postgres/version_map_test.go
+++ b/internal/postgres/version_map_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"context"
+	"fmt"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -76,3 +77,51 @@
 	vm.Status = 200
 	upsertAndVerifyVersionMap(vm)
 }
+
+func TestGetVersionMapsWithNon2xxStatus(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
+	defer cancel()
+	defer ResetTestDB(testDB, t)
+
+	tests := []struct {
+		path   string
+		status int
+	}{
+		{"github.com/a/b", 200},
+		{"github.com/a/c", 290},
+		{"github.com/a/d", 400},
+		{"github.com/a/e", 440},
+		{"github.com/a/f", 490},
+		{"github.com/a/g", 491},
+		{"github.com/a/h", 500},
+	}
+	var paths []string
+	want := map[string]bool{}
+	for _, test := range tests {
+		paths = append(paths, test.path)
+		if test.status >= 400 {
+			want[test.path] = true
+		}
+		if err := testDB.UpsertVersionMap(ctx, &internal.VersionMap{
+			ModulePath:       test.path,
+			RequestedVersion: internal.LatestVersion,
+			ResolvedVersion:  sample.VersionString,
+			GoModPath:        test.path,
+			Status:           test.status,
+		}); err != nil {
+			t.Fatal(err)
+		}
+	}
+	vms, err := testDB.GetVersionMapsNon2xxStatus(ctx, paths, internal.LatestVersion)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	got := map[string]bool{}
+	for _, vm := range vms {
+		got[vm.ModulePath] = true
+	}
+	if fmt.Sprint(want) != fmt.Sprint(got) {
+		t.Fatalf("got = \n%v\nwant =\n%v", got, want)
+	}
+}