internal/frontend: recycle database connections every 5m

In order to avoid imbalance between pkgsite's two database instances,
recycle connections every 5 minutes.

Change-Id: I9ca1e686a90f8c61619fd76454ec66163e501ee1
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/680175
kokoro-CI: kokoro <noreply+kokoro@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/cmd/frontend/main.go b/cmd/frontend/main.go
index ad946ec..4338fd0 100644
--- a/cmd/frontend/main.go
+++ b/cmd/frontend/main.go
@@ -33,6 +33,7 @@
 	"golang.org/x/pkgsite/internal/proxy"
 	"golang.org/x/pkgsite/internal/queue"
 	"golang.org/x/pkgsite/internal/queue/gcpqueue"
+	"golang.org/x/pkgsite/internal/resource"
 	"golang.org/x/pkgsite/internal/source"
 	"golang.org/x/pkgsite/internal/static"
 	"golang.org/x/pkgsite/internal/trace"
@@ -70,7 +71,7 @@
 	}
 
 	var (
-		dsg        func(context.Context) internal.DataSource
+		dsg        func(context.Context) (internal.DataSource, func())
 		fetchQueue queue.Queue
 	)
 	if *bypassLicenseCheck {
@@ -96,14 +97,19 @@
 			ProxyClientForLatest: proxyClient,
 			BypassLicenseCheck:   *bypassLicenseCheck,
 		}.New()
-		dsg = func(context.Context) internal.DataSource { return ds }
+		dsg = func(context.Context) (internal.DataSource, func()) { return ds, func() {} }
 	} else {
-		db, err := cmdconfig.OpenDB(ctx, cfg, *bypassLicenseCheck)
-		if err != nil {
-			log.Fatalf(ctx, "%v", err)
+		dbResource := resource.New(func() *postgres.DB {
+			db, err := cmdconfig.OpenDB(ctx, cfg, *bypassLicenseCheck)
+			if err != nil {
+				log.Fatalf(ctx, "%v", err)
+			}
+			return db
+		}, 5*time.Minute)
+
+		dsg = func(ctx context.Context) (internal.DataSource, func()) {
+			return dbResource.Get()
 		}
-		defer db.Close()
-		dsg = func(context.Context) internal.DataSource { return db }
 		sourceClient := source.NewClient(&http.Client{
 			Transport: new(ochttp.Transport),
 			Timeout:   config.SourceTimeout,
@@ -113,6 +119,8 @@
 		// per-request connection.
 		fetchQueue, err = gcpqueue.New(ctx, cfg, queueName, *workers, expg,
 			func(ctx context.Context, modulePath, version string) (int, error) {
+				db, release := dbResource.Get()
+				defer release()
 				return fetchserver.FetchAndUpdateState(ctx, modulePath, version, proxyClient, sourceClient, db)
 			})
 		if err != nil {
diff --git a/cmd/internal/pkgsite/server.go b/cmd/internal/pkgsite/server.go
index 7a564de..4b1979d 100644
--- a/cmd/internal/pkgsite/server.go
+++ b/cmd/internal/pkgsite/server.go
@@ -47,6 +47,8 @@
 }
 
 // BuildServer builds a *frontend.Server using the given configuration.
+//
+// TODO(rfindley): move to the cmd/pkgsite package, its only caller.
 func BuildServer(ctx context.Context, serverCfg ServerConfig) (*frontend.Server, error) {
 	if len(serverCfg.Paths) == 0 && !serverCfg.UseCache && serverCfg.Proxy == nil {
 		serverCfg.Paths = []string{"."}
@@ -288,7 +290,7 @@
 	go lds.GetUnitMeta(context.Background(), "", "std", "latest")
 
 	server, err := frontend.NewServer(frontend.ServerConfig{
-		DataSourceGetter: func(context.Context) internal.DataSource { return lds },
+		DataSourceGetter: func(context.Context) (internal.DataSource, func()) { return lds, func() {} },
 		TemplateFS:       template.TrustedFSFromEmbed(static.FS),
 		StaticFS:         staticFS,
 		DevMode:          devMode,
diff --git a/internal/frontend/fetchserver/fetch_test.go b/internal/frontend/fetchserver/fetch_test.go
index 7f05fd5..dce2de1 100644
--- a/internal/frontend/fetchserver/fetch_test.go
+++ b/internal/frontend/fetchserver/fetch_test.go
@@ -62,7 +62,7 @@
 
 	s, err := frontend.NewServer(frontend.ServerConfig{
 		FetchServer:      f,
-		DataSourceGetter: func(context.Context) internal.DataSource { return testDB },
+		DataSourceGetter: func(context.Context) (internal.DataSource, func()) { return testDB, func() {} },
 		Queue:            q,
 		TemplateFS:       template.TrustedFSFromEmbed(static.FS),
 		// Use the embedded FSs here to make sure they're tested.
diff --git a/internal/frontend/frontend_test.go b/internal/frontend/frontend_test.go
index e2b01ba..6dc6b16 100644
--- a/internal/frontend/frontend_test.go
+++ b/internal/frontend/frontend_test.go
@@ -39,7 +39,7 @@
 	t.Helper()
 
 	s, err := NewServer(ServerConfig{
-		DataSourceGetter: func(context.Context) internal.DataSource { return fakedatasource.New() },
+		DataSourceGetter: func(context.Context) (internal.DataSource, func()) { return fakedatasource.New(), func() {} },
 		TemplateFS:       template.TrustedFSFromEmbed(static.FS),
 		// Use the embedded FSs here to make sure they're tested.
 		// Integration tests will use the actual directories.
diff --git a/internal/frontend/latest_version.go b/internal/frontend/latest_version.go
index 8af58fd..72d51b9 100644
--- a/internal/frontend/latest_version.go
+++ b/internal/frontend/latest_version.go
@@ -26,7 +26,8 @@
 
 	// It is okay to use a different DataSource (DB connection) than the rest of the
 	// request, because this makes self-contained calls on the DB.
-	ds := s.getDataSource(ctx)
+	ds, release := s.getDataSource(ctx)
+	defer release()
 
 	latest, err := ds.GetLatestInfo(ctx, unitPath, modulePath, latestUnitMeta)
 	if err != nil {
diff --git a/internal/frontend/latest_version_test.go b/internal/frontend/latest_version_test.go
index be2f2f8..19f4e84 100644
--- a/internal/frontend/latest_version_test.go
+++ b/internal/frontend/latest_version_test.go
@@ -59,7 +59,7 @@
 	}
 	ctx := context.Background()
 	insertTestModules(ctx, t, fds, persistedModules)
-	svr := &Server{getDataSource: func(context.Context) internal.DataSource { return fds }}
+	svr := &Server{getDataSource: func(context.Context) (internal.DataSource, func()) { return fds, func() {} }}
 	for _, tc := range tt {
 		t.Run(tc.name, func(t *testing.T) {
 			got := svr.GetLatestInfo(ctx, tc.fullPath, tc.modulePath, nil)
diff --git a/internal/frontend/server.go b/internal/frontend/server.go
index a5e5e53..3e63f69 100644
--- a/internal/frontend/server.go
+++ b/internal/frontend/server.go
@@ -42,7 +42,7 @@
 type Server struct {
 	fetchServer FetchServerInterface
 	// getDataSource should never be called from a handler. It is called only in Server.errorHandler.
-	getDataSource      func(context.Context) internal.DataSource
+	getDataSource      func(context.Context) (internal.DataSource, func())
 	queue              queue.Queue
 	templateFS         template.TrustedFS
 	staticFS           fs.FS
@@ -82,9 +82,9 @@
 	Config *config.Config
 	// Note that FetchServer may be nil.
 	FetchServer FetchServerInterface
-	// DataSourceGetter should return a DataSource on each call.
+	// DataSourceGetter should return a DataSource and a release function on each call.
 	// It should be goroutine-safe.
-	DataSourceGetter  func(context.Context) internal.DataSource
+	DataSourceGetter  func(context.Context) (internal.DataSource, func())
 	Queue             queue.Queue
 	TemplateFS        template.TrustedFS // for loading templates safely
 	StaticFS          fs.FS              // for static/ directory
@@ -503,7 +503,8 @@
 func (s *Server) errorHandler(f func(w http.ResponseWriter, r *http.Request, ds internal.DataSource) error) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		// Obtain a DataSource to use for this request.
-		ds := s.getDataSource(r.Context())
+		ds, release := s.getDataSource(r.Context())
+		defer release()
 		if err := f(w, r, ds); err != nil {
 			s.serveError(w, r, err)
 		}
diff --git a/internal/resource/resource.go b/internal/resource/resource.go
new file mode 100644
index 0000000..57017d0
--- /dev/null
+++ b/internal/resource/resource.go
@@ -0,0 +1,92 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The resource package defines types to track the lifecycle of transient
+// resources, such as a database connection, that should be renewed at some
+// fixed interval.
+package resource
+
+import (
+	"io"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+type instance[T io.Closer] struct {
+	refs atomic.Int64
+	v    T
+}
+
+func (i *instance[T]) acquire(initial bool) (T, func()) {
+	if i.refs.Add(1) <= 1 && !initial {
+		panic("acquire on released instance")
+	}
+	return i.v, func() {
+		if i.refs.Add(-1) == 0 {
+			i.v.Close() // ignore error
+			var zero T
+			i.v = zero // aid GC
+		}
+	}
+}
+
+// A Resource is a container for a transient resource of type T that should be
+// periodically renewed, such as a database connection.
+//
+// Its Get method returns an instance of the resource, along with a release
+// function that the caller must invoke when it is done with the resource.
+//
+// The first call to Get creates a new resource instance. This instance is
+// cached and returned by subsequent calls to Get for a fixed duration. After
+// this duration expires, the next call to Get will create a new instance. The
+// expired instance is closed once all its users have released it.
+//
+// A Resource is safe for concurrent use.
+type Resource[T io.Closer] struct {
+	get      func() T
+	validFor time.Duration
+	after    func(func()) // for testing; fakes time.AfterFunc(validFor, f)
+
+	mu  sync.Mutex
+	cur *instance[T]
+}
+
+// New creates a new Resource that is valid for the given duration. The get
+// function is called to create a new resource instance when one is needed.
+func New[T io.Closer](get func() T, validFor time.Duration) *Resource[T] {
+	r := newAfter(get, func(f func()) {
+		time.AfterFunc(validFor, f)
+	})
+	r.validFor = validFor
+	return r
+}
+
+// newAfter returns a new resource with a fake after func, for testing.
+func newAfter[T io.Closer](get func() T, after func(func())) *Resource[T] {
+	return &Resource[T]{get: get, after: after}
+}
+
+// Get returns the current resource instance and a function to release it.
+// The release function must be called when the caller is done with the
+// resource.
+func (r *Resource[T]) Get() (T, func()) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.cur == nil {
+		r.cur = &instance[T]{v: r.get()}
+		// Acquire one ref for the new instance that lasts the duration.
+		//
+		// This is distinct from the ref acquired below; it ensures that the
+		// resource is not released until the duration has expired.
+		_, release := r.cur.acquire(true)
+		r.after(func() {
+			r.mu.Lock()
+			defer r.mu.Unlock()
+			release()
+			r.cur = nil
+		})
+	}
+	return r.cur.acquire(false)
+}
diff --git a/internal/resource/resource_test.go b/internal/resource/resource_test.go
new file mode 100644
index 0000000..6791a39
--- /dev/null
+++ b/internal/resource/resource_test.go
@@ -0,0 +1,188 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package resource
+
+import (
+	"slices"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+type fake struct {
+	id     int64
+	closed bool
+	mu     sync.Mutex
+}
+
+func (f *fake) Close() error {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	if f.closed {
+		panic("duplicate close")
+	}
+	f.closed = true
+	return nil
+}
+
+func (f *fake) isClosed() bool {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	return f.closed
+}
+
+// fakeTimer allows manual control over time-based events.
+type fakeTimer struct {
+	mu sync.Mutex
+	fs []func()
+}
+
+func newFakeTimer() *fakeTimer {
+	return &fakeTimer{}
+}
+
+func (t *fakeTimer) after(f func()) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.fs = append(t.fs, f)
+}
+
+func (t *fakeTimer) advance(tt *testing.T) {
+	tt.Helper()
+	t.mu.Lock()
+	fs := slices.Clone(t.fs)
+	t.fs = nil
+	t.mu.Unlock()
+	if len(fs) == 0 {
+		tt.Fatal("timer did not fire")
+	}
+	for _, f := range fs {
+		f()
+	}
+	t.fs = nil
+}
+
+func TestResource_Reuse(t *testing.T) {
+	var nextID atomic.Int64
+	get := func() *fake {
+		return &fake{id: nextID.Add(1)}
+	}
+	timer := newFakeTimer()
+	r := newAfter(get, timer.after)
+
+	f1, release1 := r.Get()
+	if f1.id != 1 {
+		t.Fatalf("f1.id = %d, want 1", f1.id)
+	}
+
+	f2, release2 := r.Get()
+	if f2.id != 1 {
+		t.Fatalf("f2.id = %d, want 1", f2.id)
+	}
+
+	release1()
+	if f1.isClosed() {
+		t.Fatal("f1 closed, want not closed")
+	}
+	release2()
+	if f1.isClosed() {
+		t.Fatal("f1 closed, want not closed")
+	}
+
+	// The resource holds its own reference, which is released by the timer.
+	timer.advance(t)
+
+	// Now all references are released, it should be closed.
+	if !f1.isClosed() {
+		t.Fatal("f1 not closed, want closed")
+	}
+}
+
+func TestResource_Expire(t *testing.T) {
+	var nextID atomic.Int64
+	get := func() *fake {
+		return &fake{id: nextID.Add(1)}
+	}
+	timer := newFakeTimer()
+	r := newAfter(get, timer.after)
+
+	f1, release1 := r.Get()
+	if f1.id != 1 {
+		t.Fatalf("f1.id = %d, want 1", f1.id)
+	}
+	release1() // Release our hold on it.
+
+	// Advance time, causing the resource's internal reference to be released.
+	timer.advance(t)
+
+	if !f1.isClosed() {
+		t.Fatal("f1 not closed, want closed")
+	}
+
+	f2, release2 := r.Get()
+	if f2.id != 2 {
+		t.Fatalf("f2.id = %d, want 2", f2.id)
+	}
+	release2()
+
+	timer.advance(t)
+	if !f2.isClosed() {
+		t.Fatal("f2 not closed, want closed")
+	}
+}
+
+func TestResource_Concurrent(t *testing.T) {
+	var nextID atomic.Int64
+	get := func() *fake {
+		return &fake{id: nextID.Add(1)}
+	}
+	timer := newFakeTimer()
+	r := newAfter(get, timer.after)
+
+	// Get the first resource so we have a handle to it.
+	f1, release1 := r.Get()
+	if f1.id != 1 {
+		t.Fatalf("f1.id = %d, want 1", f1.id)
+	}
+
+	var wg sync.WaitGroup
+	for range 10 {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			f, release := r.Get()
+			if f.id != 1 {
+				t.Errorf("got id %d, want 1", f.id)
+			}
+			// Hold the resource for a bit to create contention.
+			time.Sleep(1 * time.Millisecond)
+			release()
+		}()
+	}
+	wg.Wait()
+
+	// All goroutines have released. Now we release our initial hold.
+	release1()
+
+	// At this point, only the resource's own reference remains.
+	if f1.isClosed() {
+		t.Fatal("f1 closed, want not closed")
+	}
+
+	// Advance time to release the final reference.
+	timer.advance(t)
+
+	if !f1.isClosed() {
+		t.Fatal("f1 not closed, want closed")
+	}
+
+	// Getting a new resource should give a new ID.
+	f2, release2 := r.Get()
+	if f2.id != 2 {
+		t.Fatalf("f2.id = %d, want 2", f2.id)
+	}
+	release2()
+}
diff --git a/internal/testing/integration/frontend_test.go b/internal/testing/integration/frontend_test.go
index e9cbe99..1d78756 100644
--- a/internal/testing/integration/frontend_test.go
+++ b/internal/testing/integration/frontend_test.go
@@ -40,7 +40,7 @@
 	}
 	s, err := frontend.NewServer(frontend.ServerConfig{
 		FetchServer:      fs,
-		DataSourceGetter: func(context.Context) internal.DataSource { return testDB },
+		DataSourceGetter: func(context.Context) (internal.DataSource, func()) { return testDB, func() {} },
 		TemplateFS:       template.TrustedFSFromTrustedSource(template.TrustedSourceFromConstant(staticDir)),
 		StaticFS:         os.DirFS(staticDir),
 		ThirdPartyFS:     os.DirFS("../../../third_party"),