// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package memoize defines a "promise" abstraction that enables
// memoization of the result of calling an expensive but idempotent
// function.
//
// Call p = NewPromise(debug, f) to obtain a promise for the future
// result of calling f(), and call p.Get() to obtain that result. All
// calls to p.Get return the result of a single call of f().
// Get blocks until the function has finished (starting it if
// necessary), or until its context is cancelled.
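//
// For example (a minimal sketch; expensive is a hypothetical function
// and ctx is supplied by the caller):
//
//	f := func(ctx context.Context, arg interface{}) interface{} {
//		return expensive(ctx)
//	}
//	p := NewPromise("example", f)
//	v, err := p.Get(ctx, nil) // err is non-nil only if ctx is cancelled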
//
// A Store is a map of arbitrary keys to promises. Use Store.Promise
// to create a promise in the store. All calls to Store.Promise with
// the same key return the same promise as long as it is in the
// store. These promises are reference-counted and must be explicitly
// released. Once the last reference is released, the promise is
// removed from the store.
package memoize
import (
"context"
"fmt"
"reflect"
"runtime/trace"
"sync"
"sync/atomic"
"golang.org/x/tools/internal/xcontext"
)
// Function is the type of a function that can be memoized.
//
// If the arg is a RefCounted, its Acquire/Release operations are called.
//
// The argument must not materially affect the result of the function
// in ways that are not captured by the promise's key, since if
// Promise.Get is called twice concurrently, with the same (implicit)
// key but different arguments, the Function is called only once but
// its result must be suitable for both callers.
//
// The main purpose of the argument is to avoid the Function closure
// needing to retain large objects (in practice: the snapshot) in
// memory; such objects can instead be supplied at call time by any
// caller.
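//
// For example (a sketch; the Snapshot type and parse function are
// hypothetical):
//
//	f := func(ctx context.Context, arg interface{}) interface{} {
//		snapshot := arg.(*Snapshot) // supplied by each caller of Get
//		return parse(ctx, snapshot)
//	}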
type Function func(ctx context.Context, arg interface{}) interface{}
// A RefCounted is a value whose functional lifetime is determined by
// reference counting.
//
// Its Acquire method is called before the Function is invoked, and
// the corresponding release is called when the Function returns.
// Usually both events happen within a single call to Get, so Get
// would be fine with a "borrowed" reference, but if the context is
// cancelled, Get may return before the Function is complete, causing
// the argument to escape and risking premature destruction of the
// value. For a reference-counted type, this requires a pair of
// increment/decrement operations to extend its life.
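//
// A minimal sketch of a conforming implementation (the counted type
// is illustrative only):
//
//	type counted struct{ refs int32 }
//
//	func (c *counted) Acquire() func() {
//		atomic.AddInt32(&c.refs, 1)
//		return func() { atomic.AddInt32(&c.refs, -1) }
//	}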
type RefCounted interface {
// Acquire prevents the value from being destroyed until the
// returned function is called.
Acquire() func()
}
// A Promise represents the future result of a call to a function.
type Promise struct {
debug string // for observability
mu sync.Mutex
// A Promise starts out IDLE, waiting for something to demand
// its evaluation. It then transitions into RUNNING state.
//
// While RUNNING, waiters tracks the number of Get calls
// waiting for a result, and the done channel is used to
// notify waiters of the next state transition. Once
// evaluation finishes, value is set, state changes to
// COMPLETED, and done is closed, unblocking waiters.
//
// Alternatively, as Get calls are cancelled, they decrement
// waiters. If it drops to zero, the inner context is
// cancelled, computation is abandoned, and state resets to
// IDLE to start the process over again.
state state
// done is set in running state, and closed when exiting it.
done chan struct{}
// cancel is set in running state. It cancels computation.
cancel context.CancelFunc
// waiters is the number of Gets outstanding.
waiters uint
// the function that will be used to populate the value
function Function
// value is set in completed state.
value interface{}
refcount int32 // accessed using atomic load/store
}
// NewPromise returns a promise for the future result of calling the
// specified function.
//
// The debug string is used to classify promises in logs and metrics.
// It should be drawn from a small set.
func NewPromise(debug string, function Function) *Promise {
if function == nil {
panic("nil function")
}
return &Promise{
debug: debug,
function: function,
}
}
type state int
const (
stateIdle = iota // newly constructed, or last waiter was cancelled
stateRunning // start was called and not cancelled
stateCompleted // function call ran to completion
)
// Cached returns the value associated with a promise.
//
// It will never cause the value to be generated.
// It will return the cached value, if present.
func (p *Promise) Cached() interface{} {
p.mu.Lock()
defer p.mu.Unlock()
if p.state == stateCompleted {
return p.value
}
return nil
}
// Get returns the value associated with a promise.
//
// All calls to Promise.Get on a given promise return the
// same result but the function is called (to completion) at most once.
//
// If the value is not yet ready, the underlying function will be invoked.
//
// If ctx is cancelled, Get returns (nil, Canceled).
// If all concurrent calls to Get are cancelled, the context provided
// to the function is cancelled. A later call to Get may attempt to
// call the function again.
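//
// A typical call site looks like this (a sketch; arg and resultType
// are caller-specific):
//
//	v, err := p.Get(ctx, arg)
//	if err != nil {
//		return err // ctx was cancelled before the result was ready
//	}
//	result := v.(resultType)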
func (p *Promise) Get(ctx context.Context, arg interface{}) (interface{}, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
p.mu.Lock()
switch p.state {
case stateIdle:
return p.run(ctx, arg)
case stateRunning:
return p.wait(ctx)
case stateCompleted:
defer p.mu.Unlock()
return p.value, nil
default:
panic("unknown state")
}
}
// run starts p.function and returns the result. p.mu must be locked.
func (p *Promise) run(ctx context.Context, arg interface{}) (interface{}, error) {
childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
p.cancel = cancel
p.state = stateRunning
p.done = make(chan struct{})
function := p.function // Read under the lock
// Make sure that the argument isn't destroyed while we're running in it.
release := func() {}
if rc, ok := arg.(RefCounted); ok {
release = rc.Acquire()
}
go func() {
trace.WithRegion(childCtx, fmt.Sprintf("Promise.run %s", p.debug), func() {
defer release()
// Just in case the function does something expensive without checking
// the context, double-check we're still alive.
if childCtx.Err() != nil {
return
}
v := function(childCtx, arg)
if childCtx.Err() != nil {
return
}
p.mu.Lock()
defer p.mu.Unlock()
// It's theoretically possible that the promise has been cancelled out
// of the run that started us, and then started running again since we
// checked childCtx above. Even so, that should be harmless, since each
// run should produce the same results.
if p.state != stateRunning {
return
}
p.value = v
p.function = nil // aid GC
p.state = stateCompleted
close(p.done)
})
}()
return p.wait(ctx)
}
// wait waits for the value to be computed, or ctx to be cancelled. p.mu must be locked.
func (p *Promise) wait(ctx context.Context) (interface{}, error) {
p.waiters++
done := p.done
p.mu.Unlock()
select {
case <-done:
p.mu.Lock()
defer p.mu.Unlock()
if p.state == stateCompleted {
return p.value, nil
}
return nil, nil
case <-ctx.Done():
p.mu.Lock()
defer p.mu.Unlock()
p.waiters--
if p.waiters == 0 && p.state == stateRunning {
p.cancel()
close(p.done)
p.state = stateIdle
p.done = nil
p.cancel = nil
}
return nil, ctx.Err()
}
}
// An EvictionPolicy controls the eviction behavior of keys in a Store when
// they no longer have any references.
type EvictionPolicy int
const (
// ImmediatelyEvict evicts keys as soon as they no longer have references.
ImmediatelyEvict EvictionPolicy = iota
// NeverEvict does not evict keys.
NeverEvict
)
// A Store maps arbitrary keys to reference-counted promises.
//
// The zero value is a valid Store, though a store may also be created via
// NewStore if a custom EvictionPolicy is required.
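//
// For example (sketch):
//
//	var s1 Store               // zero value: evicts on last release
//	s2 := NewStore(NeverEvict) // retains promises after release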
type Store struct {
evictionPolicy EvictionPolicy
promisesMu sync.Mutex
promises map[interface{}]*Promise
}
// NewStore creates a new store with the given eviction policy.
func NewStore(policy EvictionPolicy) *Store {
return &Store{evictionPolicy: policy}
}
// Promise returns a reference-counted promise for the future result of
// calling the specified function.
//
// Calls to Promise with the same key return the same promise,
// incrementing its reference count. The caller must call the
// returned function to decrement the promise's reference count when
// it is no longer needed. Once the last reference has been released,
// the promise is removed from the store.
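//
// For example (a sketch; key, f, and arg are illustrative):
//
//	p, release := store.Promise(key, f)
//	defer release()
//	v, err := p.Get(ctx, arg)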
func (store *Store) Promise(key interface{}, function Function) (*Promise, func()) {
store.promisesMu.Lock()
p, ok := store.promises[key]
if !ok {
p = NewPromise(reflect.TypeOf(key).String(), function)
if store.promises == nil {
store.promises = map[interface{}]*Promise{}
}
store.promises[key] = p
}
atomic.AddInt32(&p.refcount, 1)
store.promisesMu.Unlock()
release := func() {
// TODO(rfindley): this looks racy: it's possible that the refcount is
// incremented before we grab the lock.
if atomic.AddInt32(&p.refcount, -1) == 0 && store.evictionPolicy != NeverEvict {
store.promisesMu.Lock()
delete(store.promises, key)
store.promisesMu.Unlock()
}
}
return p, release
}
// Stats returns the number of each type of key in the store.
func (s *Store) Stats() map[reflect.Type]int {
result := map[reflect.Type]int{}
s.promisesMu.Lock()
defer s.promisesMu.Unlock()
for k := range s.promises {
result[reflect.TypeOf(k)]++
}
return result
}
// DebugOnlyIterate iterates through the store and, for each completed
// promise, calls f(k, v) for the map key k and function result v. It
// should only be used for debugging purposes.
func (s *Store) DebugOnlyIterate(f func(k, v interface{})) {
s.promisesMu.Lock()
defer s.promisesMu.Unlock()
for k, p := range s.promises {
if v := p.Cached(); v != nil {
f(k, v)
}
}
}