// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package cache

import (
	"context"
	"fmt"
	"sync/atomic"
	"testing"
	"time"

	"golang.org/x/sync/errgroup"
)

func TestFutureCache_Persistent(t *testing.T) {
	c := newFutureCache[int, int](true)
	ctx := context.Background()

	var computed atomic.Int32
	compute := func(i int) cacheFunc[int] {
		return func(context.Context) (int, error) {
			computed.Add(1)
			return i, ctx.Err()
		}
	}

	testFutureCache(t, ctx, c, compute)

	// Since this cache is persistent, we should see exactly 10 computations:
	// one for each of the 10 distinct keys used by [testFutureCache].
	if got := computed.Load(); got != 10 {
		t.Errorf("computed %d times, want 10", got)
	}
}

func TestFutureCache_Ephemeral(t *testing.T) {
	c := newFutureCache[int, int](false)
	ctx := context.Background()

	var computed atomic.Int32
	compute := func(i int) cacheFunc[int] {
		return func(context.Context) (int, error) {
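			// Sleeping briefly makes it more likely that concurrent gets for the
			// same key overlap and share a single in-flight computation.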
			time.Sleep(1 * time.Millisecond)
			computed.Add(1)
			return i, ctx.Err()
		}
	}

	testFutureCache(t, ctx, c, compute)

	// Since this cache is ephemeral, we should see at least 30 computations:
	// [testFutureCache] makes three sequential passes over 10 distinct keys,
	// and an ephemeral cache retains nothing from one pass to the next.
	if got := computed.Load(); got < 30 {
		t.Errorf("computed %d times, want at least 30", got)
	} else {
		t.Logf("compute ran %d times", got)
	}
}

// testFutureCache starts 100 goroutines concurrently, indexed by j, each
// getting key j%10 from the cache. It repeats this three times, waiting for
// all goroutines to complete after each pass.
//
// This is designed to exercise both concurrent and sequential access to the
// cache.
func testFutureCache(t *testing.T, ctx context.Context, c *futureCache[int, int], compute func(int) cacheFunc[int]) {
	for range 3 {
		var g errgroup.Group
		for j := range 100 {
			mod := j % 10
			compute := compute(mod)
			g.Go(func() error {
				got, err := c.get(ctx, mod, compute)
				if err == nil && got != mod {
					t.Errorf("get() = %d, want %d", got, mod)
				}
				return err
			})
		}
		if err := g.Wait(); err != nil {
			t.Fatal(err)
		}
	}
}

func TestFutureCache_Retrying(t *testing.T) {
	// This test verifies the retry behavior of cache entries by checking that
	// cancelled work is handed off to the next awaiter.
	//
	// The setup is a little tricky: 10 goroutines are started, and the first 9
	// are cancelled whereas the 10th is allowed to finish. As a result, the
	// computation should always succeed with value 9.

	ctx := context.Background()

	for _, persistent := range []bool{true, false} {
		t.Run(fmt.Sprintf("persistent=%t", persistent), func(t *testing.T) {
			c := newFutureCache[int, int](persistent)

			var started atomic.Int32 // counts computations that actually began (logged below)

			// compute returns a new cacheFunc that produces the value i after
			// receiving from the provided done channel, or fails if ctx is
			// cancelled first.
			compute := func(i int, done <-chan struct{}) cacheFunc[int] {
				return func(ctx context.Context) (int, error) {
					started.Add(1)
					select {
					case <-ctx.Done():
						return 0, ctx.Err()
					case <-done:
						return i, nil
					}
				}
			}

			// Each goroutine is either cancelled or allowed to complete, as
			// controlled by the corresponding entries in cancels and dones.
			var (
				cancels = make([]func(), 10)
				dones   = make([]chan struct{}, 10)
			)

			var g errgroup.Group
			var lastValue atomic.Int32 // keep track of the last successfully computed value
			for i := range 10 {
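				// Give each goroutine its own cancellable context and done channel,
				// so that it can be cancelled or completed independently.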
				ctx, cancel := context.WithCancel(ctx)
				done := make(chan struct{})
				cancels[i] = cancel
				dones[i] = done
				compute := compute(i, done)
				g.Go(func() error {
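					// The cancelled gets are expected to fail, so don't propagate
					// errors to the group; just record successfully computed values.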
					v, err := c.get(ctx, 0, compute)
					if err == nil {
						lastValue.Store(int32(v))
					}
					return nil
				})
			}
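			// Cancel the first nine goroutines, so that their pending work is
			// handed off to a subsequent awaiter.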
			for _, cancel := range cancels[:9] {
				cancel()
			}
			defer cancels[9]()

			dones[9] <- struct{}{} // unblock the tenth computation, allowing it to succeed
			g.Wait()

			t.Logf("started %d computations", started.Load())
			if got := lastValue.Load(); got != 9 {
				t.Errorf("after cancelling computations 0-8, got %d, want 9", got)
			}
		})
	}
}