blob: 1b06d90bb770ab1a8a40c5418f0672c279adc7c4 [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"testing"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/testing/sample"
)
// TestGetLatestMajorPathForV1Path inserts the modules listed in each test case
// and checks that GetLatestMajorPathForV1Path returns the expected path and
// major version, both for the module's v1 path and for the path of a package
// inside that module.
func TestGetLatestMajorPathForV1Path(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	for _, test := range []struct {
		name           string
		v1ModulePath   string
		modvers        []string
		wantModulePath string
		wantVersion    int
	}{
		{
			"want highest major version",
			"m.com",
			[]string{"m.com@v1.0.0", "m.com/v2@v2.0.0", "m.com/v11@v11.0.0"},
			"m.com/v11", 11,
		},
		{
			"only v1 version",
			"m.com",
			[]string{"m.com@v1.0.0"},
			"m.com", 1,
		},
		{
			"no v1 version",
			"m.com",
			[]string{"m.com/v4@v4.0.0"},
			"m.com/v4", 4,
		},
		{
			"gopkg.in",
			"gopkg.in/yaml",
			[]string{"gopkg.in/yaml.v1@v1.0.0", "gopkg.in/yaml.v2@v2.0.0"},
			"gopkg.in/yaml.v2", 2,
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			testDB, release := acquire(t)
			defer release()

			// Every inserted module gets a package at this suffix, so the
			// "package" subtest below can query a package path as well.
			const suffix = "a/b/c"

			// check takes the *testing.T of the subtest that calls it.
			// Using the enclosing test's t here would call Fatal on a T
			// from a goroutine that is not running that test, and would
			// attribute failures to the parent instead of the subtest.
			check := func(t *testing.T, v1path, wantPath string) {
				t.Helper()
				gotPath, gotVer, err := testDB.GetLatestMajorPathForV1Path(ctx, v1path)
				if err != nil {
					t.Fatal(err)
				}
				if gotPath != wantPath || gotVer != test.wantVersion {
					t.Errorf("GetLatestMajorPathForV1Path(%q) = %q, %d, want %q, %d",
						v1path, gotPath, gotVer, wantPath, test.wantVersion)
				}
			}

			for _, mv := range test.modvers {
				// Only the first two results are used here.
				mod, ver, _ := parseModuleVersionPackage(mv)
				m := sample.Module(mod, ver, suffix)
				MustInsertModule(ctx, t, testDB, m)
			}

			t.Run("module", func(t *testing.T) {
				check(t, test.v1ModulePath, test.wantModulePath)
			})
			t.Run("package", func(t *testing.T) {
				check(t, test.v1ModulePath+"/"+suffix, test.wantModulePath+"/"+suffix)
			})
		})
	}
}
// TestUpsertPathConcurrently verifies that we get no constraint violations or
// other errors when the same path is upserted multiple times concurrently.
// Each of the n goroutines runs upsertPath for "a/path" in its own
// repeatable-read transaction and also fails if the returned ID is zero.
func TestUpsertPathConcurrently(t *testing.T) {
	t.Parallel()
	testDB, release := acquire(t)
	defer release()
	ctx := context.Background()

	const n = 10
	// Buffered so that no goroutine ever blocks on its send.
	errc := make(chan error, n)
	for i := 0; i < n; i++ {
		go func() {
			errc <- testDB.db.Transact(ctx, sql.LevelRepeatableRead, func(tx *database.DB) error {
				id, err := upsertPath(ctx, tx, "a/path")
				if err != nil {
					return err
				}
				if id == 0 {
					return errors.New("zero id")
				}
				return nil
			})
		}()
	}

	// Receive all n results before returning. Using t.Error rather than
	// t.Fatal keeps the deferred release from running while some goroutines
	// may still be in the middle of a transaction on testDB, and reports
	// every failure instead of only the first.
	for i := 0; i < n; i++ {
		if err := <-errc; err != nil {
			t.Error(err)
		}
	}
}
// TestUpsertPaths upserts two overlapping batches of paths, one after the
// other, and uses checkPathMap to validate the map returned for each batch.
func TestUpsertPaths(t *testing.T) {
	t.Parallel()
	testDB, release := acquire(t)
	defer release()
	ctx := context.Background()

	// The second batch repeats "b" and "c" from the first and adds new paths.
	batches := [][]string{
		{"a", "b", "c"},
		{"b", "c", "d", "e"},
	}
	for _, batch := range batches {
		ids, err := upsertPathsInTx(ctx, testDB.db, batch)
		if err != nil {
			t.Fatal(err)
		}
		checkPathMap(t, ids, batch)
	}
}
// checkPathMap reports a test error if got does not have exactly len(paths)
// entries, or if any element of paths is missing from got or maps to a zero
// ID. When the sizes differ it reports only that mismatch and returns.
func checkPathMap(t *testing.T, got map[string]int, paths []string) {
	t.Helper()
	if len(got) != len(paths) {
		t.Errorf("got %d paths, want %d", len(got), len(paths))
		return
	}
	for _, path := range paths {
		switch id, found := got[path]; {
		case !found:
			t.Errorf("missing path %q", path)
		case id == 0:
			t.Errorf("path %q has a 0 ID", path)
		}
	}
}
// TestUpsertPathsConcurrently verifies that we get no constraint violations
// or other errors when overlapping sets of paths are upserted multiple times
// concurrently. Each of the n goroutines upserts a window of up to 50 paths
// taken from a shared list of 100, starting at offset 10*i.
func TestUpsertPathsConcurrently(t *testing.T) {
	t.Parallel()
	testDB, release := acquire(t)
	defer release()
	ctx := context.Background()

	const n = 10
	paths := make([]string, 100)
	for i := 0; i < len(paths); i++ {
		paths[i] = fmt.Sprintf("p%d", i)
	}

	// Buffered so that no goroutine ever blocks on its send.
	errc := make(chan error, n)
	for i := 0; i < n; i++ {
		i := i
		go func() {
			start := (10 * i) % len(paths)
			end := start + 50
			if end > len(paths) {
				end = len(paths)
			}
			sub := paths[start:end]
			got, err := upsertPathsInTx(ctx, testDB.db, sub)
			if err == nil {
				// checkPathMap only calls t.Errorf, never t.Fatal, so it
				// is allowed on this goroutine while the test is running.
				checkPathMap(t, got, sub)
			}
			errc <- err
		}()
	}

	// Receive all n results before returning. Stopping at the first error
	// with t.Fatal would end the test while other goroutines can still call
	// t.Errorf through checkPathMap, which the testing package treats as a
	// panic once the test has completed. It would also run the deferred
	// release while they are still using testDB.
	for i := 0; i < n; i++ {
		if err := <-errc; err != nil {
			t.Error(err)
		}
	}
}
// upsertPathsInTx calls upsertPaths for the given paths inside a
// repeatable-read transaction on db. It returns the map produced by
// upsertPaths, or a nil map and the error if the transaction fails.
func upsertPathsInTx(ctx context.Context, db *database.DB, paths []string) (map[string]int, error) {
	var ids map[string]int
	upsert := func(tx *database.DB) (err error) {
		ids, err = upsertPaths(ctx, tx, paths)
		return err
	}
	if err := db.Transact(ctx, sql.LevelRepeatableRead, upsert); err != nil {
		return nil, err
	}
	return ids, nil
}