blob: 7b75295eb4dab03363efee1f56fbfa693edeca78 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context_test
import (
"context"
"sync"
"testing"
"time"
)
// afterFuncContext is a context that's not one of the types
// defined in context.go, that supports registering AfterFuncs.
type afterFuncContext struct {
mu sync.Mutex
afterFuncs map[*byte]func()
done chan struct{}
err error
}
func newAfterFuncContext() context.Context {
return &afterFuncContext{}
}
func (c *afterFuncContext) Deadline() (time.Time, bool) {
return time.Time{}, false
}
func (c *afterFuncContext) Done() <-chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()
if c.done == nil {
c.done = make(chan struct{})
}
return c.done
}
func (c *afterFuncContext) Err() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.err
}
func (c *afterFuncContext) Value(key any) any {
return nil
}
func (c *afterFuncContext) AfterFunc(f func()) func() bool {
c.mu.Lock()
defer c.mu.Unlock()
k := new(byte)
if c.afterFuncs == nil {
c.afterFuncs = make(map[*byte]func())
}
c.afterFuncs[k] = f
return func() bool {
c.mu.Lock()
defer c.mu.Unlock()
_, ok := c.afterFuncs[k]
delete(c.afterFuncs, k)
return ok
}
}
func (c *afterFuncContext) cancel(err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.err != nil {
return
}
c.err = err
for _, f := range c.afterFuncs {
go f()
}
c.afterFuncs = nil
}
func TestCustomContextAfterFuncCancel(t *testing.T) {
ctx0 := &afterFuncContext{}
ctx1, cancel := context.WithCancel(ctx0)
defer cancel()
ctx0.cancel(context.Canceled)
<-ctx1.Done()
}
func TestCustomContextAfterFuncTimeout(t *testing.T) {
ctx0 := &afterFuncContext{}
ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
defer cancel()
ctx0.cancel(context.Canceled)
<-ctx1.Done()
}
func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
ctx0 := &afterFuncContext{}
donec := make(chan struct{})
stop := context.AfterFunc(ctx0, func() {
close(donec)
})
defer stop()
ctx0.cancel(context.Canceled)
<-donec
}
func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
ctx0 := &afterFuncContext{}
_, cancel1 := context.WithCancel(ctx0)
_, cancel2 := context.WithCancel(ctx0)
if got, want := len(ctx0.afterFuncs), 2; got != want {
t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
}
cancel1()
cancel2()
if got, want := len(ctx0.afterFuncs), 0; got != want {
t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
}
}
func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
ctx0 := &afterFuncContext{}
_, cancel := context.WithTimeout(ctx0, veryLongDuration)
if got, want := len(ctx0.afterFuncs), 1; got != want {
t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
}
cancel()
if got, want := len(ctx0.afterFuncs), 0; got != want {
t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
}
}
func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
ctx0 := &afterFuncContext{}
stop := context.AfterFunc(ctx0, func() {})
if got, want := len(ctx0.afterFuncs), 1; got != want {
t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
}
stop()
if got, want := len(ctx0.afterFuncs), 0; got != want {
t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
}
}