| // Copyright 2023 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package context_test |
| |
| import ( |
| "context" |
| "sync" |
| "testing" |
| "time" |
| ) |
| |
// afterFuncContext is a context that's not one of the types
// defined in context.go, that supports registering AfterFuncs.
//
// It exposes an AfterFunc(func()) func() bool method in addition to the
// context.Context methods; the tests in this file count entries in
// afterFuncs to observe how the context package registers and unregisters
// callbacks on such a parent.
type afterFuncContext struct {
	mu sync.Mutex // guards all fields below

	// afterFuncs holds callbacks registered via AfterFunc that have been
	// neither stopped nor started. Each registration is keyed by its own
	// freshly allocated *byte so entries are distinct. The map is allocated
	// lazily on first registration and reset to nil by cancel.
	afterFuncs map[*byte]func()

	// done is the channel returned by Done; allocated lazily on first use.
	done chan struct{}

	// err is nil until cancel records the cancellation error; cancel only
	// takes effect while err is still nil.
	err error
}
| |
| func newAfterFuncContext() context.Context { |
| return &afterFuncContext{} |
| } |
| |
| func (c *afterFuncContext) Deadline() (time.Time, bool) { |
| return time.Time{}, false |
| } |
| |
| func (c *afterFuncContext) Done() <-chan struct{} { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| if c.done == nil { |
| c.done = make(chan struct{}) |
| } |
| return c.done |
| } |
| |
| func (c *afterFuncContext) Err() error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| return c.err |
| } |
| |
| func (c *afterFuncContext) Value(key any) any { |
| return nil |
| } |
| |
| func (c *afterFuncContext) AfterFunc(f func()) func() bool { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| k := new(byte) |
| if c.afterFuncs == nil { |
| c.afterFuncs = make(map[*byte]func()) |
| } |
| c.afterFuncs[k] = f |
| return func() bool { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| _, ok := c.afterFuncs[k] |
| delete(c.afterFuncs, k) |
| return ok |
| } |
| } |
| |
| func (c *afterFuncContext) cancel(err error) { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| if c.err != nil { |
| return |
| } |
| c.err = err |
| for _, f := range c.afterFuncs { |
| go f() |
| } |
| c.afterFuncs = nil |
| } |
| |
| func TestCustomContextAfterFuncCancel(t *testing.T) { |
| ctx0 := &afterFuncContext{} |
| ctx1, cancel := context.WithCancel(ctx0) |
| defer cancel() |
| ctx0.cancel(context.Canceled) |
| <-ctx1.Done() |
| } |
| |
| func TestCustomContextAfterFuncTimeout(t *testing.T) { |
| ctx0 := &afterFuncContext{} |
| ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration) |
| defer cancel() |
| ctx0.cancel(context.Canceled) |
| <-ctx1.Done() |
| } |
| |
| func TestCustomContextAfterFuncAfterFunc(t *testing.T) { |
| ctx0 := &afterFuncContext{} |
| donec := make(chan struct{}) |
| stop := context.AfterFunc(ctx0, func() { |
| close(donec) |
| }) |
| defer stop() |
| ctx0.cancel(context.Canceled) |
| <-donec |
| } |
| |
| func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) { |
| ctx0 := &afterFuncContext{} |
| _, cancel1 := context.WithCancel(ctx0) |
| _, cancel2 := context.WithCancel(ctx0) |
| if got, want := len(ctx0.afterFuncs), 2; got != want { |
| t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) |
| } |
| cancel1() |
| cancel2() |
| if got, want := len(ctx0.afterFuncs), 0; got != want { |
| t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) |
| } |
| } |
| |
| func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) { |
| ctx0 := &afterFuncContext{} |
| _, cancel := context.WithTimeout(ctx0, veryLongDuration) |
| if got, want := len(ctx0.afterFuncs), 1; got != want { |
| t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want) |
| } |
| cancel() |
| if got, want := len(ctx0.afterFuncs), 0; got != want { |
| t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want) |
| } |
| } |
| |
| func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) { |
| ctx0 := &afterFuncContext{} |
| stop := context.AfterFunc(ctx0, func() {}) |
| if got, want := len(ctx0.afterFuncs), 1; got != want { |
| t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want) |
| } |
| stop() |
| if got, want := len(ctx0.afterFuncs), 0; got != want { |
| t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want) |
| } |
| } |