blob: 4671f8340e23bcbeb2e2496e1dcfedf0be65711b [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package quic
import (
"context"
"errors"
"fmt"
"path/filepath"
"runtime"
"sync"
)
// asyncTestState permits handling asynchronous operations in a synchronous test.
//
// For example, a test may want to write to a stream and observe that
// STREAM frames are sent with the contents of the write in response
// to MAX_STREAM_DATA frames received from the peer.
// The Stream.Write is an asynchronous operation, but the test is simpler
// if we can start the write, observe the first STREAM frame sent,
// send a MAX_STREAM_DATA frame, observe the next STREAM frame sent, etc.
//
// We do this by instrumenting points where operations can block.
// We start async operations like Write in a goroutine,
// and wait for the operation to either finish or hit a blocking point.
// When the connection event loop is idle, we check a list of
// blocked operations to see if any can be woken.
type asyncTestState struct {
	// mu guards blocked.
	mu sync.Mutex

	// notify is an unbuffered channel, created lazily by runAsync.
	// An async operation's goroutine sends on it each time the operation
	// either blocks (in waitUntil) or returns; runAsync, asyncOp.cancel,
	// and wakeAsync receive from it to wait for that to happen.
	notify chan struct{}

	// blocked is the set of operations currently parked in waitUntil.
	// waitUntil adds an entry before blocking; wakeAsync removes an
	// entry when it wakes that operation.
	blocked map[*blockedAsync]struct{}
}
// An asyncOp is an asynchronous operation that results in (T, error).
// It is created by runAsync, which runs the operation in its own goroutine.
type asyncOp[T any] struct {
	// v and err hold the operation's results. The operation's goroutine
	// writes them before closing donec, so they should only be read
	// once donec is closed.
	v   T
	err error

	caller     string             // "file:line" of the runAsync call site, used in error messages
	tc         *testConn          // the test connection passed to runAsync
	donec      chan struct{}      // closed after the operation has returned
	cancelFunc context.CancelFunc // cancels the Context passed to the operation
}
// cancel stops an operation that is still running.
//
// If the operation has already finished, cancel does nothing. Otherwise it
// cancels the operation's Context and waits for the operation's goroutine
// to report in on the shared notify channel. It panics if the operation is
// still not done at that point, because the operation did not honor the
// cancellation.
func (a *asyncOp[T]) cancel() {
	finished := func() bool {
		select {
		case <-a.donec:
			return true
		default:
			return false
		}
	}
	if finished() {
		return
	}
	a.cancelFunc()
	// The goroutine signals notify when it returns or blocks again;
	// a canceled Context is expected to make it return.
	<-a.tc.asyncTestState.notify
	if !finished() {
		panic(fmt.Errorf("%v: async op failed to finish after being canceled", a.caller))
	}
}
// errNotDone is returned by asyncOp.result while the operation has not yet
// completed. runAsync's end-of-test cleanup also checks for it.
var errNotDone = errors.New("async op is not done")
// result reports the outcome of the async operation.
//
// It first calls a.tc.wait() (presumably letting the connection process
// any pending work), then checks without blocking whether the operation
// has finished. If it has, the operation's own value and error are
// returned; otherwise result returns the zero T and errNotDone.
//
// Unlike a traditional async/await, result never waits for the operation
// to complete. Tests have full control over the progress of operations,
// so an asyncOp only becomes done in reaction to something the test does.
func (a *asyncOp[T]) result() (v T, err error) {
	a.tc.wait()
	select {
	case <-a.donec:
		v, err = a.v, a.err
	default:
		err = errNotDone
	}
	return v, err
}
// A blockedAsync is a blocked async operation.
//
// waitUntil creates one for each operation that has to block and records
// it in asyncTestState.blocked. wakeAsync removes it and closes donec to
// let the operation resume.
type blockedAsync struct {
	until func() bool   // when this returns true, the operation may be woken
	donec chan struct{} // closed when the operation is unblocked
}
// asyncContextKey is the Context value key runAsync uses to mark the
// Contexts it creates. waitUntil panics rather than block with a Context
// that lacks this mark.
type asyncContextKey struct{}
// runAsync starts an asynchronous operation.
//
// The function f should call a blocking function such as
// Stream.Write or Conn.AcceptStream and return its result.
// It must use the provided context.
//
// runAsync runs f in a new goroutine and returns once f has either
// returned or blocked at an instrumented point (waitUntil). It also
// registers a test cleanup. If the operation is still running when the
// test ends, the cleanup reports an error and cancels the operation.
func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[T] {
	// Must be called directly from runAsync so that frame 1 is its caller.
	_, file, line, _ := runtime.Caller(1)

	as := &tc.asyncTestState
	if as.notify == nil {
		// First async operation on this conn: set up the shared state.
		as.notify = make(chan struct{})
		as.mu.Lock()
		as.blocked = make(map[*blockedAsync]struct{})
		as.mu.Unlock()
	}

	// The marker value lets waitUntil tell this Context apart from foreign ones.
	ctx, cancel := context.WithCancel(
		context.WithValue(context.Background(), asyncContextKey{}, true))
	op := &asyncOp[T]{
		tc:         tc,
		caller:     fmt.Sprintf("%v:%v", filepath.Base(file), line),
		donec:      make(chan struct{}),
		cancelFunc: cancel,
	}

	go func() {
		op.v, op.err = f(ctx)
		close(op.donec)
		// Tell whoever is waiting on notify that the operation has returned.
		as.notify <- struct{}{}
	}()

	tc.t.Cleanup(func() {
		_, err := op.result()
		if err != errNotDone {
			return
		}
		tc.t.Errorf("%v: async operation is still executing at end of test", op.caller)
		op.cancel()
	})

	// Wait until the goroutine reports that f finished or blocked.
	<-as.notify
	tc.wait()
	return op
}
// waitUntil blocks the calling async operation until until returns true
// or ctx is done.
//
// It is called at instrumented blocking points inside an operation started
// by runAsync. If until already reports true, it returns nil immediately,
// and if ctx has already expired it returns ctx.Err(). Otherwise it records
// the operation in as.blocked and signals on as.notify that the operation
// has stopped making progress. It then waits either to be woken by
// wakeAsync (returning nil) or for ctx to be done (returning ctx.Err()).
//
// It panics if it would have to block with a Context that runAsync did not
// create, since the test harness cannot drive such an operation.
func (as *asyncTestState) waitUntil(ctx context.Context, until func() bool) error {
	if until() {
		return nil
	}
	if err := ctx.Err(); err != nil {
		// Context has already expired.
		return err
	}
	if ctx.Value(asyncContextKey{}) == nil {
		// Context is not one that we've created, and hasn't expired.
		// This probably indicates that we've tried to perform a
		// blocking operation without using the async test harness here,
		// which may have unpredictable results.
		panic("blocking async point with unexpected Context")
	}
	b := &blockedAsync{
		until: until,
		donec: make(chan struct{}),
	}
	// Record this as a pending blocking operation.
	as.mu.Lock()
	as.blocked[b] = struct{}{}
	as.mu.Unlock()
	// Notify the creator of the operation that we're blocked,
	// and wait to be woken up.
	as.notify <- struct{}{}
	select {
	case <-b.donec:
		return nil
	case <-ctx.Done():
		// We are giving up on being woken, so drop our entry. If it stayed,
		// wakeAsync could later pick it, close donec for a waiter that is
		// gone, and then block forever on notify. That signal never comes,
		// because this operation has already returned and sent its one
		// completion signal.
		as.mu.Lock()
		delete(as.blocked, b)
		as.mu.Unlock()
		return ctx.Err()
	}
}
// wakeAsync tries to wake up a blocked async operation.
// It reports whether an operation was woken.
//
// At most one operation is woken per call. It is the first entry in
// as.blocked, in map iteration order, whose until func reports true. After
// waking it, wakeAsync waits on as.notify until that operation either
// blocks again or finishes.
func (as *asyncTestState) wakeAsync() bool {
	// pop removes and returns a blocked operation that is ready to run,
	// or nil if none is.
	pop := func() *blockedAsync {
		as.mu.Lock()
		for b := range as.blocked {
			if b.until() {
				delete(as.blocked, b)
				as.mu.Unlock()
				return b
			}
		}
		as.mu.Unlock()
		return nil
	}

	ready := pop()
	if ready == nil {
		return false
	}
	close(ready.donec)
	// as.mu must not be held here: if the woken operation blocks again,
	// waitUntil takes as.mu before it sends on notify.
	<-as.notify
	return true
}