blob: 4237b14364a39f273e60d52516f6644a4d8f3b69 [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Infrastructure for testing ClientConn.RoundTrip.
// Put actual tests in transport_test.go.
package http2
import (
"bytes"
"fmt"
"io"
"net"
"net/http"
"reflect"
"slices"
"testing"
"time"
"golang.org/x/net/http2/hpack"
)
// TestTestClientConn is a worked example of driving a testClientConn
// through one complete request/response exchange.
func TestTestClientConn(t *testing.T) {
	// Build a *ClientConn together with the fake-server scaffolding around it.
	tc := newTestClientConn(t)

	// Consume the client's opening SETTINGS and WINDOW_UPDATE frames and
	// answer with a SETTINGS frame. Extra settings can be handed to greet
	// as optional arguments.
	tc.greet()

	// A request body has to be either constant (bytes.Buffer, strings.Reader)
	// or one obtained from newRequestBody.
	reqBody := tc.newRequestBody()
	reqBody.writeBytes(10)         // ten arbitrary bytes...
	reqBody.closeWithError(io.EOF) // ...and then EOF.

	// roundTrip starts RoundTrip without waiting for it to return
	// and hands back a testRoundTrip that tracks the call.
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", reqBody)
	rt := tc.roundTrip(req)
	id := rt.streamID()

	// tc offers several methods for checking the frames the client sent.
	// First the request headers...
	expectedHeaders := wantHeader{
		streamID:  id,
		endStream: false,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"PUT"},
			":path":      []string{"/"},
		},
	}
	tc.wantHeaders(expectedHeaders)

	// ...then 10 bytes of request body carried in DATA frames.
	expectedData := wantData{
		streamID:  id,
		endStream: true,
		size:      10,
	}
	tc.wantData(expectedData)

	// Answer with a HEADERS frame from the fake server.
	respHeaders := HeadersFrameParam{
		StreamID:      id,
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: tc.makeHeaderBlockFragment(":status", "200"),
	}
	tc.writeHeaders(respHeaders)

	// With headers delivered, RoundTrip has returned. testRoundTrip can
	// inspect the response, or expose the response and/or error directly.
	rt.wantStatus(200)
	rt.wantBody(nil)
}
// A testClientConn allows testing ClientConn.RoundTrip against a fake server.
//
// A test using testClientConn consists of:
// - actions on the client (calling RoundTrip, making data available to Request.Body);
// - validation of frames sent by the client to the server; and
// - providing frames from the server to the client.
//
// testClientConn manages synchronization, so tests can generally be written as
// a linear sequence of actions and validations without additional synchronization.
//
// The same value, converted to testClientConnNetConn, is installed as the
// ClientConn's net.Conn by newTestClientConnFromClientConn.
type testClientConn struct {
t *testing.T // test used for Helper/Fatalf/Cleanup reporting
tr *Transport // the ClientConn's Transport (cc.t)
fr *Framer // fake-server framer, built over rbuf and wbuf
cc *ClientConn // the ClientConn under test
hooks *testSyncHooks // the Transport's sync hooks (cc.t.syncHooks)
encbuf bytes.Buffer // output buffer of enc; reset by makeHeaderBlockFragment
enc *hpack.Encoder // HPACK encoder writing into encbuf
roundtrips []*testRoundTrip // every round trip started through roundTrip
rerr error // returned by Read
netConnClosed bool // set when the ClientConn closes the net.Conn
rbuf bytes.Buffer // sent to the test conn
wbuf bytes.Buffer // sent by the test conn
}
// newTestClientConnFromClientConn wraps an existing ClientConn in a
// testClientConn and installs that testClientConn as the conn's net.Conn.
//
// It also registers a test cleanup which, once the conn is idle, sets the
// read error to io.EOF (unless a read error is already set) and waits for
// the conn to become idle again.
func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
	tc := new(testClientConn)
	tc.t = t
	tc.tr = cc.t
	tc.cc = cc
	tc.hooks = cc.t.syncHooks
	cc.tconn = (*testClientConnNetConn)(tc)

	tc.enc = hpack.NewEncoder(&tc.encbuf)

	// Fake-server framer over the two buffers shared with the net.Conn side.
	framer := NewFramer(&tc.rbuf, &tc.wbuf)
	framer.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
	framer.SetMaxReadFrameSize(10 << 20)
	tc.fr = framer

	shutdown := func() {
		tc.sync()
		if tc.rerr == nil {
			tc.rerr = io.EOF
		}
		tc.sync()
	}
	t.Cleanup(shutdown)
	return tc
}
// readClientPreface consumes the HTTP/2 client preface, which the client
// sends before any frames, and fails the test if it is missing or wrong.
func (tc *testClientConn) readClientPreface() {
	tc.t.Helper()
	got := make([]byte, len(clientPreface))
	_, err := io.ReadFull(&tc.wbuf, got)
	switch {
	case err != nil:
		tc.t.Fatalf("reading preface: %v", err)
	case !bytes.Equal(got, clientPreface):
		tc.t.Fatalf("client preface: %q, want %q", got, clientPreface)
	}
}
// newTestClientConn creates a test Transport (applying opts to it), opens a
// non-single-use ClientConn on it, and returns the testClientConn wrapping
// that conn with the client preface already consumed.
func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
	t.Helper()
	tt := newTestTransport(t, opts...)
	if _, err := tt.tr.newClientConn(nil, false /* singleUse */, tt.tr.syncHooks); err != nil {
		t.Fatalf("newClientConn: %v", err)
	}
	return tt.getConn()
}
// sync blocks until the ClientConn under test has settled: every goroutine
// it runs is blocked waiting for some input.
func (tc *testClientConn) sync() {
	hooks := tc.hooks
	hooks.waitInactive()
}
// advance moves the synthetic clock forward by d and then waits for the
// ClientConn to settle again.
func (tc *testClientConn) advance(d time.Duration) {
	hooks := tc.hooks
	hooks.advance(d)
	tc.sync()
}
// hasFrame reports whether the client has written bytes that have not yet
// been read as a frame.
func (tc *testClientConn) hasFrame() bool {
	unread := tc.wbuf.Len()
	return unread != 0
}
// readFrame reads the next frame from the conn.
//
// It returns nil when the client has written nothing more, and also when
// the framer fails to parse what was written. In the latter case the
// framing error is logged, so a following "got no frame" failure shows the
// underlying cause instead of hiding it.
func (tc *testClientConn) readFrame() Frame {
	tc.t.Helper()
	if tc.wbuf.Len() == 0 {
		return nil
	}
	fr, err := tc.fr.ReadFrame()
	if err != nil {
		// Keep the best-effort contract (nil result), but don't lose the error.
		tc.t.Logf("readFrame: error reading frame from client: %v", err)
		return nil
	}
	return fr
}
// testClientConnReadFrame reads the next frame from the conn and returns it
// as a T. It fails the test if no frame is available or if the frame has a
// different type.
func testClientConnReadFrame[T any](tc *testClientConn) T {
	tc.t.Helper()
	var zero T
	fr := tc.readFrame()
	if fr == nil {
		tc.t.Fatalf("got no frame, want frame %T", zero)
	}
	typed, ok := fr.(T)
	if !ok {
		tc.t.Fatalf("got frame %T, want %T", fr, typed)
	}
	return typed
}
// wantFrameType consumes the next frame from the conn and fails the test
// unless a frame is available and its type equals want.
func (tc *testClientConn) wantFrameType(want FrameType) {
	tc.t.Helper()
	fr := tc.readFrame()
	if fr == nil {
		tc.t.Fatalf("got no frame, want frame %v", want)
	}
	got := fr.Header().Type
	if got == want {
		return
	}
	tc.t.Fatalf("got frame %v, want %v", got, want)
}
// wantUnorderedFrames reads frames from the conn until every condition in
// want has been satisfied.
//
// Each element of want must be a func(*SomeFrame) bool. For every frame read,
// the first still-pending func whose parameter type matches the frame is
// called with it; once a func returns true it is retired.
//
// The test fails if a frame arrives that no pending func accepts (no func
// has that frame type, or all funcs of that type have already returned true),
// or if the conn runs out of frames while funcs are still pending.
//
// Example:
//
//	// Read a SETTINGS frame, and any number of DATA frames for a stream.
//	// The SETTINGS frame may appear anywhere in the sequence.
//	// The last DATA frame must indicate the end of the stream.
//	tc.wantUnorderedFrames(
//		func(f *SettingsFrame) bool {
//			return true
//		},
//		func(f *DataFrame) bool {
//			return f.StreamEnded()
//		},
//	)
func (tc *testClientConn) wantUnorderedFrames(want ...any) {
	tc.t.Helper()
	// Work on a copy so retiring conditions does not touch the caller's slice.
	pending := slices.Clone(want)
	boolType := reflect.TypeOf(true)
	satisfied := 0

	for satisfied < len(pending) && !tc.t.Failed() {
		fr := tc.readFrame()
		if fr == nil {
			break
		}
		frameType := reflect.TypeOf(fr)
		handled := false
		for i, cond := range pending {
			if cond == nil {
				continue // already satisfied
			}
			condType := reflect.TypeOf(cond)
			wellFormed := condType.Kind() == reflect.Func &&
				condType.NumIn() == 1 &&
				condType.NumOut() == 1 &&
				condType.Out(0) == boolType
			if !wellFormed {
				tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", cond)
			}
			if condType.In(0) != frameType {
				continue
			}
			args := []reflect.Value{reflect.ValueOf(fr)}
			if reflect.ValueOf(cond).Call(args)[0].Bool() {
				pending[i] = nil
				satisfied++
			}
			// Only the first matching condition gets to see a frame.
			handled = true
			break
		}
		if !handled {
			tc.t.Errorf("got unexpected frame type %T", fr)
		}
	}

	if satisfied >= len(pending) {
		return
	}
	for _, cond := range pending {
		if cond != nil {
			tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(cond).In(0))
		}
	}
	tc.t.Fatalf("did not see %v expected frame types", len(pending)-satisfied)
}
// wantHeader describes the header frames a test expects the client to send.
// It is the argument to testClientConn.wantHeaders.
type wantHeader struct {
streamID uint32 // stream ID the frame must carry
endStream bool // whether the frame must end the stream
header http.Header // header fields that must be present with exactly these values
}
// wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
// and asserts that they contain the expected headers.
//
// The stream ID and END_STREAM flag must match want exactly. Every key in
// want.header must be present with exactly the listed values; header fields
// the client sent that are not named in want.header are ignored.
func (tc *testClientConn) wantHeaders(want wantHeader) {
	tc.t.Helper()
	got := testClientConnReadFrame[*MetaHeadersFrame](tc)
	if got, want := got.StreamID, want.streamID; got != want {
		tc.t.Fatalf("got stream ID %v, want %v", got, want)
	}
	if got, want := got.StreamEnded(), want.endStream; got != want {
		tc.t.Fatalf("got stream ended %v, want %v", got, want)
	}
	// Collect the decoded fields by name so repeated fields compare as a list.
	gotHeader := make(http.Header)
	for _, f := range got.Fields {
		gotHeader[f.Name] = append(gotHeader[f.Name], f.Value)
	}
	for k, wantValues := range want.header {
		gotValues := gotHeader[k]
		if !reflect.DeepEqual(wantValues, gotValues) {
			// Report the received value as "got" and the expected one as "want".
			tc.t.Fatalf("got header %q = %q; want %q", k, gotValues, wantValues)
		}
	}
}
// wantData describes the DATA frames a test expects the client to send.
// It is the argument to testClientConn.wantData.
type wantData struct {
streamID uint32 // stream the DATA frames are expected on
endStream bool // whether the frames are expected to end the stream
size int // total number of payload bytes expected across the frames
}
// wantData reads zero or more DATA frames, and asserts that they match the expectation.
//
// Frames are consumed until none remain or one carries END_STREAM. Each
// frame must belong to want.streamID, the payload sizes must add up to
// want.size, and END_STREAM must have been seen exactly when want.endStream
// is set.
func (tc *testClientConn) wantData(want wantData) {
	tc.t.Helper()
	gotSize := 0
	gotEndStream := false
	for tc.hasFrame() && !gotEndStream {
		data := testClientConnReadFrame[*DataFrame](tc)
		// A DATA frame on another stream must not count toward this expectation.
		if got := data.Header().StreamID; got != want.streamID {
			tc.t.Fatalf("got DATA frame for stream %v, want stream %v", got, want.streamID)
		}
		gotSize += len(data.Data())
		if data.StreamEnded() {
			gotEndStream = true
		}
	}
	if gotSize != want.size {
		tc.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
	}
	if gotEndStream != want.endStream {
		tc.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
	}
}
// testRequestBody is a Request.Body for use in tests.
//
// The test supplies content with Write or writeBytes and finishes the body
// with closeWithError; Read blocks (through the conn's sync hooks) until
// content or an error is available.
type testRequestBody struct {
tc *testClientConn // conn whose sync hooks and test handle the body uses
// At most one of buf or bytes can be set at any given time:
buf bytes.Buffer // specific bytes to read from the body
bytes int // body contains this many arbitrary bytes
err error // read error (comes after any available bytes)
}
// newRequestBody returns an empty testRequestBody bound to tc.
func (tc *testClientConn) newRequestBody() *testRequestBody {
	return &testRequestBody{tc: tc}
}
// Read is called by the ClientConn to read from a request body.
//
// It blocks until the test has supplied bytes or an error. Bytes written
// with Write are returned first; otherwise arbitrary bytes queued with
// writeBytes are returned as 'A's; once both are exhausted the error set by
// closeWithError is returned.
func (b *testRequestBody) Read(p []byte) (n int, _ error) {
	ready := func() bool {
		return b.buf.Len() > 0 || b.bytes > 0 || b.err != nil
	}
	b.tc.cc.syncHooks.blockUntil(ready)

	if b.buf.Len() > 0 {
		return b.buf.Read(p)
	}
	if b.bytes <= 0 {
		return 0, b.err
	}

	// Hand out at most the number of arbitrary bytes still owed.
	count := len(p)
	if count > b.bytes {
		count = b.bytes
	}
	b.bytes -= count
	for i := 0; i < count; i++ {
		p[i] = 'A'
	}
	return count, nil
}
// Close is called by the ClientConn once it has finished reading the
// request body. It does nothing and always succeeds.
func (b *testRequestBody) Close() error { return nil }
// writeBytes makes n more arbitrary bytes available to readers of the body,
// then waits for the ClientConn to settle.
func (b *testRequestBody) writeBytes(n int) {
	b.bytes = b.bytes + n
	b.checkWrite()
	b.tc.sync()
}
// Write appends p to the bytes available to readers of the body, then waits
// for the ClientConn to settle.
func (b *testRequestBody) Write(p []byte) (int, error) {
	written, werr := b.buf.Write(p)
	b.checkWrite()
	b.tc.sync()
	return written, werr
}
// checkWrite fails the test if the body has been written to in an
// unsupported way: mixing Write with writeBytes, or writing after
// closeWithError.
func (b *testRequestBody) checkWrite() {
	hasArbitrary := b.bytes > 0
	hasSpecific := b.buf.Len() > 0
	if hasArbitrary && hasSpecific {
		b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
	}
	if closed := b.err != nil; closed {
		b.tc.t.Fatalf("can't write to request body after closeWithError")
	}
}
// closeWithError records err as the error Read returns once all available
// bytes have been consumed, then waits for the ClientConn to settle.
func (b *testRequestBody) closeWithError(err error) {
	body := b
	body.err = err
	body.tc.sync()
}
// roundTrip starts a RoundTrip call on the ClientConn and returns a
// testRoundTrip tracking it.
//
// (The RoundTrip itself won't complete until response headers arrive, the
// request times out, or some other terminal condition is reached.)
//
// A cleanup is registered that closes the response body if the RoundTrip
// has finished with a response by the end of the test.
func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
	rt := &testRoundTrip{t: tc.t, donec: make(chan struct{})}
	tc.roundtrips = append(tc.roundtrips, rt)

	// Capture the clientStream created while RoundTrip starts up; the hook
	// is removed again as soon as the conn has settled.
	tc.hooks.newstream = func(cs *clientStream) { rt.cs = cs }
	tc.cc.goRun(func() {
		defer close(rt.donec)
		rt.resp, rt.respErr = tc.cc.RoundTrip(req)
	})
	tc.sync()
	tc.hooks.newstream = nil

	tc.t.Cleanup(func() {
		if rt.done() {
			if res, _ := rt.result(); res != nil {
				res.Body.Close()
			}
		}
	})
	return rt
}
// greet performs the opening handshake with the client: it expects the
// client's initial SETTINGS and WINDOW_UPDATE frames, sends a SETTINGS frame
// carrying the given settings plus a SETTINGS acknowledgement, and then
// expects the client's SETTINGS acknowledgement.
func (tc *testClientConn) greet(settings ...Setting) {
	// Attribute handshake failures to the calling test, like the other helpers.
	tc.t.Helper()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.writeSettings(settings...)
	tc.writeSettingsAck()
	tc.wantFrameType(FrameSettings) // acknowledgement
}
// writeSettings sends a SETTINGS frame with the given settings to the
// client and waits for the client to settle.
func (tc *testClientConn) writeSettings(settings ...Setting) {
	tc.t.Helper()
	err := tc.fr.WriteSettings(settings...)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writeSettingsAck sends a SETTINGS acknowledgement to the client and waits
// for the client to settle.
func (tc *testClientConn) writeSettingsAck() {
	tc.t.Helper()
	err := tc.fr.WriteSettingsAck()
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writeData sends a DATA frame on streamID to the client and waits for the
// client to settle.
func (tc *testClientConn) writeData(streamID uint32, endStream bool, data []byte) {
	tc.t.Helper()
	err := tc.fr.WriteData(streamID, endStream, data)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writeDataPadded sends a padded DATA frame on streamID to the client and
// waits for the client to settle.
func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
	tc.t.Helper()
	err := tc.fr.WriteDataPadded(streamID, endStream, data, pad)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// makeHeaderBlockFragment encodes headers in a form suitable for inclusion
// in a HEADERS or CONTINUATION frame.
//
// It takes a list of alternating names and values, and fails the test if
// the list has odd length or a field cannot be encoded.
//
// The returned slice is a copy owned by the caller: the encoder's buffer is
// reset and reused on every call, so returning it directly would let a later
// call overwrite a fragment the test is still holding.
func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
	tc.t.Helper()
	if len(s)%2 != 0 {
		tc.t.Fatalf("uneven list of header name/value pairs")
	}
	tc.encbuf.Reset()
	for i := 0; i < len(s); i += 2 {
		field := hpack.HeaderField{Name: s[i], Value: s[i+1]}
		if err := tc.enc.WriteField(field); err != nil {
			tc.t.Fatalf("encoding header field %q: %v", field.Name, err)
		}
	}
	return bytes.Clone(tc.encbuf.Bytes())
}
// writeHeaders sends a HEADERS frame described by p to the client and waits
// for the client to settle.
func (tc *testClientConn) writeHeaders(p HeadersFrameParam) {
	tc.t.Helper()
	err := tc.fr.WriteHeaders(p)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writeHeadersMode writes header frames, as modified by mode:
//
//   - noHeader: Don't write the header.
//   - oneHeader: Write a single HEADERS frame.
//   - splitHeader: Write a HEADERS frame and CONTINUATION frame.
//
// In splitHeader mode the HEADERS frame carries the first byte of the block
// fragment and the CONTINUATION frame carries the rest together with p's
// EndHeaders flag. It panics on an unknown mode, or in splitHeader mode when
// the block fragment is shorter than two bytes.
func (tc *testClientConn) writeHeadersMode(mode headerType, p HeadersFrameParam) {
	tc.t.Helper()
	switch mode {
	case noHeader:
		return
	case oneHeader:
		tc.writeHeaders(p)
		return
	case splitHeader:
		// Handled below.
	default:
		panic("bogus mode")
	}

	if len(p.BlockFragment) < 2 {
		panic("too small")
	}
	head, tail := p.BlockFragment[:1], p.BlockFragment[1:]
	endHeaders := p.EndHeaders
	p.BlockFragment = head
	p.EndHeaders = false
	tc.writeHeaders(p)
	tc.writeContinuation(p.StreamID, endHeaders, tail)
}
// writeContinuation sends a CONTINUATION frame on streamID to the client
// and waits for the client to settle.
func (tc *testClientConn) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
	tc.t.Helper()
	err := tc.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writeRSTStream sends an RST_STREAM frame for streamID with the given
// error code to the client and waits for the client to settle.
func (tc *testClientConn) writeRSTStream(streamID uint32, code ErrCode) {
	tc.t.Helper()
	err := tc.fr.WriteRSTStream(streamID, code)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writePing sends a PING frame (or a PING acknowledgement when ack is set)
// to the client and waits for the client to settle.
func (tc *testClientConn) writePing(ack bool, data [8]byte) {
	tc.t.Helper()
	err := tc.fr.WritePing(ack, data)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writeGoAway sends a GOAWAY frame to the client and waits for the client
// to settle.
func (tc *testClientConn) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
	tc.t.Helper()
	err := tc.fr.WriteGoAway(maxStreamID, code, debugData)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// writeWindowUpdate sends a WINDOW_UPDATE frame for streamID to the client
// and waits for the client to settle.
func (tc *testClientConn) writeWindowUpdate(streamID, incr uint32) {
	tc.t.Helper()
	err := tc.fr.WriteWindowUpdate(streamID, incr)
	if err != nil {
		tc.t.Fatal(err)
	}
	tc.sync()
}
// closeWrite makes the net.Conn used by the ClientConn return err from Read
// calls once buffered data has been consumed, then waits for the client to
// settle.
func (tc *testClientConn) closeWrite(err error) {
	conn := tc
	conn.rerr = err
	conn.sync()
}
// inflowWindow returns the amount of inbound flow control available for a
// stream, or for the connection as a whole when streamID is 0.
//
// If the stream does not exist, it reports a test error and returns -1.
func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
	cc := tc.cc
	cc.mu.Lock()
	defer cc.mu.Unlock()

	if streamID != 0 {
		cs := cc.streams[streamID]
		if cs == nil {
			tc.t.Errorf("no stream with id %v", streamID)
			return -1
		}
		return cs.inflow.avail + cs.inflow.unsent
	}
	return cc.inflow.avail + cc.inflow.unsent
}
// testRoundTrip manages a RoundTrip in progress.
//
// resp and respErr are written by the goroutine running RoundTrip and are
// only read by result after donec has been closed.
type testRoundTrip struct {
t *testing.T // test used for failure reporting
resp *http.Response // response returned by RoundTrip
respErr error // error returned by RoundTrip
donec chan struct{} // closed when RoundTrip returns
cs *clientStream // stream captured by testClientConn.roundTrip's newstream hook; testTransport.roundTrip does not set it
}
// streamID returns the HTTP/2 stream ID of the request.
// It panics if no stream has been recorded for the round trip.
func (rt *testRoundTrip) streamID() uint32 {
	if cs := rt.cs; cs != nil {
		return cs.ID
	}
	panic("stream ID unknown")
}
// done reports, without blocking, whether RoundTrip has returned.
func (rt *testRoundTrip) done() bool {
	finished := false
	select {
	case <-rt.donec:
		finished = true
	default:
	}
	return finished
}
// result returns the response and error produced by RoundTrip.
// It fails the test if RoundTrip has not returned yet.
func (rt *testRoundTrip) result() (*http.Response, error) {
	rt.t.Helper()
	if !rt.done() {
		rt.t.Fatalf("RoundTrip is not done; want it to be")
	}
	return rt.resp, rt.respErr
}
// response returns the response of a successful RoundTrip.
// It calls t.Fatal if the RoundTrip failed, or returned neither a response
// nor an error.
func (rt *testRoundTrip) response() *http.Response {
	rt.t.Helper()
	resp, err := rt.result()
	switch {
	case err != nil:
		rt.t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
	case resp == nil:
		rt.t.Fatalf("RoundTrip returned nil *Response and nil error")
	}
	return resp
}
// err returns the error result of RoundTrip, which may be nil.
// It fails the test if RoundTrip has not returned yet.
func (rt *testRoundTrip) err() error {
	rt.t.Helper()
	_, rtErr := rt.result()
	return rtErr
}
// wantStatus fails the test unless the response's StatusCode equals want.
func (rt *testRoundTrip) wantStatus(want int) {
	rt.t.Helper()
	got := rt.response().StatusCode
	if got == want {
		return
	}
	rt.t.Fatalf("got response status %v, want %v", got, want)
}
// readBody reads the response body to the end and returns its contents
// along with any read error.
func (rt *testRoundTrip) readBody() ([]byte, error) {
	rt.t.Helper()
	body := rt.response().Body
	return io.ReadAll(body)
}
// wantBody fails the test unless the response body can be read without
// error and equals want.
// (Note that this consumes the body.)
func (rt *testRoundTrip) wantBody(want []byte) {
	rt.t.Helper()
	got, err := rt.readBody()
	switch {
	case err != nil:
		rt.t.Fatalf("unexpected error reading response body: %v", err)
	case !bytes.Equal(got, want):
		rt.t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want)
	}
}
// wantHeaders fails the test unless the response headers equal want.
func (rt *testRoundTrip) wantHeaders(want http.Header) {
	rt.t.Helper()
	got := rt.response().Header
	diff := diffHeaders(got, want)
	if diff == "" {
		return
	}
	rt.t.Fatalf("unexpected response headers:\n%v", diff)
}
// wantTrailers fails the test unless the response trailers equal want.
func (rt *testRoundTrip) wantTrailers(want http.Header) {
	rt.t.Helper()
	got := rt.response().Trailer
	diff := diffHeaders(got, want)
	if diff == "" {
		return
	}
	rt.t.Fatalf("unexpected response trailers:\n%v", diff)
}
// diffHeaders compares two header maps. It returns "" when they are equal
// and a two-line "got/want" description otherwise. A nil map and an empty
// non-nil map are considered equal.
func diffHeaders(got, want http.Header) string {
	bothEmpty := len(got) == 0 && len(want) == 0
	// A more sophisticated diff is possible; DeepEqual is good enough for now.
	if bothEmpty || reflect.DeepEqual(got, want) {
		return ""
	}
	return fmt.Sprintf("got: %v\nwant: %v", got, want)
}
// testClientConnNetConn implements net.Conn.
//
// It is a testClientConn under another name: Read serves bytes from rbuf
// (falling back to rerr), Write appends to wbuf, and Close sets
// netConnClosed. newTestClientConnFromClientConn installs it as the
// ClientConn's tconn.
type testClientConnNetConn testClientConn
// Read blocks until the fake server has queued data or a read error has
// been set. Queued data is returned first; the read error is returned only
// once the queue is empty.
func (nc *testClientConnNetConn) Read(b []byte) (n int, err error) {
	readable := func() bool {
		return nc.rerr != nil || nc.rbuf.Len() > 0
	}
	nc.cc.syncHooks.blockUntil(readable)
	if nc.rbuf.Len() == 0 {
		return 0, nc.rerr
	}
	return nc.rbuf.Read(b)
}
// Write appends the client's output to wbuf, where the test reads it back
// as frames.
func (nc *testClientConnNetConn) Write(b []byte) (n int, err error) {
	n, err = nc.wbuf.Write(b)
	return n, err
}
// Close records that the ClientConn closed the net.Conn. It always succeeds.
func (nc *testClientConnNetConn) Close() error {
	closed := true
	nc.netConnClosed = closed
	return nil
}
// LocalAddr implements net.Conn. The fake conn has no address; it returns nil.
func (*testClientConnNetConn) LocalAddr() net.Addr { return nil }

// RemoteAddr implements net.Conn. The fake conn has no address; it returns nil.
func (*testClientConnNetConn) RemoteAddr() net.Addr { return nil }

// SetDeadline implements net.Conn. Deadlines are ignored.
func (*testClientConnNetConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline implements net.Conn. Deadlines are ignored.
func (*testClientConnNetConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline implements net.Conn. Deadlines are ignored.
func (*testClientConnNetConn) SetWriteDeadline(time.Time) error { return nil }
// A testTransport allows testing Transport.RoundTrip against fake servers.
// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
// should use testClientConn instead.
type testTransport struct {
t *testing.T // test used for failure reporting and cleanup
tr *Transport // the Transport under test
ccs []*testClientConn // conns created by the Transport and not yet taken with getConn
}
// newTestTransport creates a Transport with test sync hooks, applies opts to
// it, and arranges for every ClientConn the Transport creates to be wrapped
// in a testClientConn and queued for retrieval with getConn.
//
// At test cleanup it waits for the Transport to go idle, fails the test if
// any queued conn was never retrieved, and reports an error if goroutines
// tracked by the sync hooks are still running.
func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport {
	tr := &Transport{syncHooks: newTestSyncHooks()}
	for i := range opts {
		opts[i](tr)
	}

	tt := &testTransport{t: t, tr: tr}
	tr.syncHooks.newclientconn = func(cc *ClientConn) {
		tc := newTestClientConnFromClientConn(t, cc)
		tt.ccs = append(tt.ccs, tc)
	}

	t.Cleanup(func() {
		tt.sync()
		if unexamined := len(tt.ccs); unexamined > 0 {
			t.Fatalf("%v test ClientConns created, but not examined by test", unexamined)
		}
		if running := tt.tr.syncHooks.total; running != 0 {
			t.Errorf("%v goroutines still running after test completed", running)
		}
	})
	return tt
}
// sync blocks until the Transport's sync hooks report it inactive.
func (tt *testTransport) sync() {
	hooks := tt.tr.syncHooks
	hooks.waitInactive()
}
// advance moves the synthetic clock forward by d and then waits for the
// Transport to settle.
func (tt *testTransport) advance(d time.Duration) {
	hooks := tt.tr.syncHooks
	hooks.advance(d)
	tt.sync()
}
// hasConn reports whether the Transport has created a ClientConn that the
// test has not yet taken with getConn.
func (tt *testTransport) hasConn() bool {
	queued := len(tt.ccs)
	return queued != 0
}
// getConn removes and returns the oldest ClientConn created by the Transport
// that the test has not yet examined. It waits for the conn to settle and
// consumes its client preface before returning. The test fails if no conn
// is queued.
func (tt *testTransport) getConn() *testClientConn {
	tt.t.Helper()
	if !tt.hasConn() {
		tt.t.Fatalf("no new ClientConns created; wanted one")
	}
	tc, rest := tt.ccs[0], tt.ccs[1:]
	tt.ccs = rest
	tc.sync()
	tc.readClientPreface()
	return tc
}
// roundTrip starts a Transport.RoundTrip call in a goroutine managed by the
// sync hooks, waits for the Transport to settle, and returns a testRoundTrip
// tracking the call.
//
// A cleanup is registered that closes the response body if the RoundTrip
// has finished with a response by the end of the test.
func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
	rt := &testRoundTrip{t: tt.t, donec: make(chan struct{})}
	tt.tr.syncHooks.goRun(func() {
		defer close(rt.donec)
		rt.resp, rt.respErr = tt.tr.RoundTrip(req)
	})
	tt.sync()

	tt.t.Cleanup(func() {
		if rt.done() {
			if res, _ := rt.result(); res != nil {
				res.Body.Close()
			}
		}
	})
	return rt
}