blob: d96dc0f5317402f0475b6ec9d613f56130a003cd [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cache
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
"golang.org/x/sync/errgroup"
)
// TestFutureCache_Persistent checks that a persistent futureCache runs the
// computation for each distinct key exactly once, even though
// [testFutureCache] requests every key many times across three passes.
func TestFutureCache_Persistent(t *testing.T) {
	c := newFutureCache[int, int](true)
	ctx := context.Background()

	var computed atomic.Int32
	// compute returns a cacheFunc that counts its invocations and yields i.
	// It reports the error of the context it is invoked with (rather than the
	// outer test context), so a cancelled invocation is never reported as a
	// success.
	compute := func(i int) cacheFunc[int] {
		return func(ctx context.Context) (int, error) {
			computed.Add(1)
			return i, ctx.Err()
		}
	}

	testFutureCache(t, ctx, c, compute)

	// Because this cache is persistent, each of the 10 distinct keys used by
	// [testFutureCache] should be computed exactly once.
	if got := computed.Load(); got != 10 {
		t.Errorf("computed %d times, want 10", got)
	}
}
// TestFutureCache_Ephemeral checks that an ephemeral (non-persistent)
// futureCache recomputes values across the synchronous passes of
// [testFutureCache], so that at least one computation per key per pass runs.
func TestFutureCache_Ephemeral(t *testing.T) {
	c := newFutureCache[int, int](false)
	ctx := context.Background()

	var computed atomic.Int32
	// compute returns a cacheFunc that sleeps briefly, counts its invocations
	// and yields i. It reports the error of the context it is invoked with
	// (rather than the outer test context).
	compute := func(i int) cacheFunc[int] {
		return func(ctx context.Context) (int, error) {
			// NOTE(review): the sleep presumably keeps each computation in
			// flight long enough for concurrent callers of the same key to
			// overlap with it; confirm if tuning.
			time.Sleep(1 * time.Millisecond)
			computed.Add(1)
			return i, ctx.Err()
		}
	}

	testFutureCache(t, ctx, c, compute)

	// Because this cache is ephemeral, we should get at least 30 computations:
	// there are 10 distinct keys and three synchronous passes in
	// [testFutureCache].
	if got := computed.Load(); got < 30 {
		t.Errorf("computed %d times, want at least 30", got)
	} else {
		t.Logf("compute ran %d times", got)
	}
}
// testFutureCache starts 100 goroutines concurrently, indexed by j, each
// getting key j%10 from the cache. It repeats this three times, synchronizing
// after each.
//
// This is designed to exercise both concurrent and synchronous access to the
// cache.
//
// compute(k) supplies the cacheFunc passed to c.get for key k, and every get
// of key k is expected to yield k. A wrong value is reported with t.Errorf.
// Any error returned by get fails the test via t.Fatal after the pass
// completes.
func testFutureCache(t *testing.T, ctx context.Context, c *futureCache[int, int], compute func(int) cacheFunc[int]) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	for range 3 {
		var g errgroup.Group
		for j := range 100 {
			key := j % 10
			fn := compute(key)
			g.Go(func() error {
				got, err := c.get(ctx, key, fn)
				if err == nil && got != key {
					// t.Errorf (unlike t.Fatal) may be called from other goroutines.
					t.Errorf("get() = %d, want %d", got, key)
				}
				return err
			})
		}
		// Wait for the whole pass before starting the next one, so later
		// passes access the cache only after earlier gets have returned.
		if err := g.Wait(); err != nil {
			t.Fatal(err)
		}
	}
}
// TestFutureCache_Retrying verifies the retry behavior of cache entries, by
// checking that cancelled work is handed off to the next awaiter.
func TestFutureCache_Retrying(t *testing.T) {
	// The setup is a little tricky: 10 goroutines are started, and the first 9
	// are cancelled whereas the 10th is allowed to finish. As a result, the
	// computation should always succeed with value 9.
	ctx := context.Background()

	for _, persistent := range []bool{true, false} {
		t.Run(fmt.Sprintf("persistent=%t", persistent), func(t *testing.T) {
			c := newFutureCache[int, int](persistent)
			var started atomic.Int32 // number of cacheFunc invocations

			// compute returns a new cacheFunc that produces the value i after the
			// provided done channel is closed, or fails with the context's error
			// if its context is cancelled first.
			compute := func(i int, done <-chan struct{}) cacheFunc[int] {
				return func(ctx context.Context) (int, error) {
					started.Add(1)
					select {
					case <-ctx.Done():
						return 0, ctx.Err()
					case <-done:
						return i, nil
					}
				}
			}

			// goroutines are either cancelled, or allowed to complete,
			// as controlled by cancels and dones.
			var (
				cancels = make([]func(), 10)
				dones   = make([]chan struct{}, 10)
			)

			var g errgroup.Group
			var lastValue atomic.Int32 // keep track of the last successfully computed value
			for i := range 10 {
				ctx, cancel := context.WithCancel(ctx)
				done := make(chan struct{})
				cancels[i] = cancel
				dones[i] = done
				compute := compute(i, done)
				g.Go(func() error {
					v, err := c.get(ctx, 0, compute)
					if err == nil {
						lastValue.Store(int32(v))
					}
					return nil
				})
			}
			for _, cancel := range cancels[:9] {
				cancel()
			}
			defer cancels[9]()

			// Release computation 9 by closing its channel rather than sending on
			// it. A send would block forever if the cache never invoked
			// computation 9 (for example, if it wrongly returned a cancelled
			// result to goroutine 9), hanging the test until its timeout instead
			// of failing with the diagnostic below.
			close(dones[9])

			// Every goroutine returns nil, so Wait serves only to synchronize.
			g.Wait()

			t.Logf("started %d computations", started.Load())
			if got := lastValue.Load(); got != 9 {
				t.Errorf("after cancelling computation 0-8, got %d, want 9", got)
			}
		})
	}
}