| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Infrastructure for testing ClientConn.RoundTrip. |
| // Put actual tests in transport_test.go. |
| |
| package http2_test |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto/tls" |
| "errors" |
| "fmt" |
| "io" |
| "net" |
| "net/http" |
| "reflect" |
| "slices" |
| "sync" |
| "sync/atomic" |
| "testing" |
| "testing/synctest" |
| "time" |
| |
| "golang.org/x/net/http2" |
| . "golang.org/x/net/http2" |
| "golang.org/x/net/http2/hpack" |
| "golang.org/x/net/internal/gate" |
| ) |
| |
| // TestTestClientConn demonstrates usage of testClientConn. |
| func TestTestClientConn(t *testing.T) { synctestTest(t, testTestClientConn) } |
| func testTestClientConn(t testing.TB) { |
| // newTestClientConn creates a *ClientConn and surrounding test infrastructure. |
| tc := newTestClientConn(t) |
| |
| // tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames, |
| // and sends a SETTINGS frame to the client. |
| // |
| // Additional settings may be provided as optional parameters to greet. |
| tc.greet() |
| |
| // Request bodies must either be constant (bytes.Buffer, strings.Reader) |
| // or created with newRequestBody. |
| body := tc.newRequestBody() |
| body.writeBytes(10) // 10 arbitrary bytes... |
| body.closeWithError(io.EOF) // ...followed by EOF. |
| |
| // tc.roundTrip calls RoundTrip, but does not wait for it to return. |
| // It returns a testRoundTrip. |
| req, _ := http.NewRequest("PUT", "https://dummy.tld/", body) |
| rt := tc.roundTrip(req) |
| |
| // tc has a number of methods to check for expected frames sent. |
| // Here, we look for headers and the request body. |
| tc.wantHeaders(wantHeader{ |
| streamID: rt.streamID(), |
| endStream: false, |
| header: http.Header{ |
| ":authority": []string{"dummy.tld"}, |
| ":method": []string{"PUT"}, |
| ":path": []string{"/"}, |
| }, |
| }) |
| // Expect 10 bytes of request body in DATA frames. |
| tc.wantData(wantData{ |
| streamID: rt.streamID(), |
| endStream: true, |
| size: 10, |
| multiple: true, |
| }) |
| |
| // tc.writeHeaders sends a HEADERS frame back to the client. |
| tc.writeHeaders(HeadersFrameParam{ |
| StreamID: rt.streamID(), |
| EndHeaders: true, |
| EndStream: true, |
| BlockFragment: tc.makeHeaderBlockFragment( |
| ":status", "200", |
| ), |
| }) |
| |
| // Now that we've received headers, RoundTrip has finished. |
| // testRoundTrip has various methods to examine the response, |
| // or to fetch the response and/or error returned by RoundTrip |
| rt.wantStatus(200) |
| rt.wantBody(nil) |
| } |
| |
| func TestTestTransport(t *testing.T) { |
| synctestSubtest(t, "nethttp", func(t testing.TB) { |
| testTestTransport(t, roundTripNetHTTP) |
| }) |
| synctestSubtest(t, "xnethttp2", func(t testing.TB) { |
| testTestTransport(t, roundTripXNetHTTP2) |
| }) |
| } |
| func testTestTransport(t testing.TB, mode roundTripTestMode) { |
| tt := newTestTransport(t) |
| |
| req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) |
| rt := tt.roundTrip(req) |
| tc := tt.getConn() |
| tc.wantFrameType(FrameSettings) |
| tc.wantFrameType(FrameWindowUpdate) |
| |
| tc.wantHeaders(wantHeader{ |
| streamID: 1, |
| endStream: true, |
| header: http.Header{ |
| ":authority": []string{"dummy.tld"}, |
| ":method": []string{"GET"}, |
| ":path": []string{"/"}, |
| }, |
| }) |
| |
| tc.writeSettings() |
| tc.writeSettingsAck() |
| tc.wantFrameType(FrameSettings) // acknowledgement |
| tc.writeHeaders(HeadersFrameParam{ |
| StreamID: 1, |
| EndHeaders: true, |
| EndStream: true, |
| BlockFragment: tc.makeHeaderBlockFragment( |
| ":status", "200", |
| ), |
| }) |
| |
| rt.wantStatus(200) |
| } |
| |
| // A testClientConn allows testing ClientConn.RoundTrip against a fake server. |
| // |
| // A test using testClientConn consists of: |
| // - actions on the client (calling RoundTrip, making data available to Request.Body); |
| // - validation of frames sent by the client to the server; and |
| // - providing frames from the server to the client. |
| // |
| // testClientConn manages synchronization, so tests can generally be written as |
| // a linear sequence of actions and validations without additional synchronization. |
| type testClientConn struct { |
| t testing.TB |
| |
| tr *Transport |
| fr *Framer |
| cc *ClientConn |
| cc1 *httpClientConn |
| testConnFramer |
| |
| encbuf bytes.Buffer |
| enc *hpack.Encoder |
| |
| netconn *synctestNetConn |
| connReader *nonblockingReader |
| } |
| |
| func newTestClientConnFromNetConn(tt *testTransport, nc net.Conn) *testClientConn { |
| tc := &testClientConn{ |
| t: tt.t, |
| tr: tt.tr, |
| } |
| |
| var writer io.Writer |
| var reader io.Reader |
| if tt.useTLS { |
| tlsConfig := testTLSServerConfig.Clone() |
| tlsConfig.NextProtos = []string{"h2"} |
| tlsConn := tls.Server(nc, tlsConfig) |
| reader = tlsConn |
| writer = tlsConn |
| } else { |
| reader = nc |
| writer = nc |
| } |
| tc.connReader = newNonblockingReader(reader) |
| |
| tc.netconn = nc.(*synctestNetConn) |
| tc.enc = hpack.NewEncoder(&tc.encbuf) |
| tc.fr = NewFramer(writer, tc.connReader) |
| tc.testConnFramer = testConnFramer{ |
| t: tt.t, |
| fr: tc.fr, |
| dec: hpack.NewDecoder(InitialHeaderTableSize, nil), |
| } |
| tc.fr.SetMaxReadFrameSize(10 << 20) |
| tt.t.Cleanup(func() { |
| tc.closeWrite() |
| }) |
| |
| return tc |
| } |
| |
| func (tc *testClientConn) readClientPreface() { |
| tc.t.Helper() |
| // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. |
| synctest.Wait() |
| buf := make([]byte, len(ClientPreface)) |
| if _, err := io.ReadFull(tc.connReader, buf); err != nil { |
| tc.t.Fatalf("reading preface: %v", err) |
| } |
| if !bytes.Equal(buf, []byte(ClientPreface)) { |
| tc.t.Fatalf("client preface: %q, want %q", buf, ClientPreface) |
| } |
| } |
| |
| // hasFrame reports whether a frame is available to be read. |
| func (tc *testClientConn) hasFrame() bool { |
| synctest.Wait() |
| return tc.connReader.buf.Len() > 0 |
| } |
| |
| // isClosed reports whether the peer has closed the connection. |
| func (tc *testClientConn) isClosed() bool { |
| synctest.Wait() |
| return tc.netconn.IsClosedByPeer() |
| } |
| |
| // closeWrite causes the net.Conn used by the ClientConn to return a error |
| // from Read calls. |
| func (tc *testClientConn) closeWrite() { |
| tc.netconn.Close() |
| } |
| |
| // closeWrite causes the net.Conn used by the ClientConn to return a error |
| // from Write calls. |
| func (tc *testClientConn) closeWriteWithError(err error) { |
| tc.netconn.loc.setReadError(io.EOF) |
| tc.netconn.loc.setWriteError(err) |
| } |
| |
| // testRequestBody is a Request.Body for use in tests. |
| type testRequestBody struct { |
| tc *testClientConn |
| gate gate.Gate |
| |
| // At most one of buf or bytes can be set at any given time: |
| buf bytes.Buffer // specific bytes to read from the body |
| bytes int // body contains this many arbitrary bytes |
| |
| err error // read error (comes after any available bytes) |
| } |
| |
| func (tc *testClientConn) newRequestBody() *testRequestBody { |
| b := &testRequestBody{ |
| tc: tc, |
| gate: gate.New(false), |
| } |
| return b |
| } |
| |
| func (b *testRequestBody) unlock() { |
| b.gate.Unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil) |
| } |
| |
| // Read is called by the ClientConn to read from a request body. |
| func (b *testRequestBody) Read(p []byte) (n int, _ error) { |
| if err := b.gate.WaitAndLock(context.Background()); err != nil { |
| return 0, err |
| } |
| defer b.unlock() |
| switch { |
| case b.buf.Len() > 0: |
| return b.buf.Read(p) |
| case b.bytes > 0: |
| if len(p) > b.bytes { |
| p = p[:b.bytes] |
| } |
| b.bytes -= len(p) |
| for i := range p { |
| p[i] = 'A' |
| } |
| return len(p), nil |
| default: |
| return 0, b.err |
| } |
| } |
| |
| // Close is called by the ClientConn when it is done reading from a request body. |
| func (b *testRequestBody) Close() error { |
| return nil |
| } |
| |
| // writeBytes adds n arbitrary bytes to the body. |
| func (b *testRequestBody) writeBytes(n int) { |
| defer synctest.Wait() |
| b.gate.Lock() |
| defer b.unlock() |
| b.bytes += n |
| b.checkWrite() |
| synctest.Wait() |
| } |
| |
| // Write adds bytes to the body. |
| func (b *testRequestBody) Write(p []byte) (int, error) { |
| defer synctest.Wait() |
| b.gate.Lock() |
| defer b.unlock() |
| n, err := b.buf.Write(p) |
| b.checkWrite() |
| return n, err |
| } |
| |
| func (b *testRequestBody) checkWrite() { |
| if b.bytes > 0 && b.buf.Len() > 0 { |
| b.tc.t.Fatalf("can't interleave Write and writeBytes on request body") |
| } |
| if b.err != nil { |
| b.tc.t.Fatalf("can't write to request body after closeWithError") |
| } |
| } |
| |
| // closeWithError sets an error which will be returned by Read. |
| func (b *testRequestBody) closeWithError(err error) { |
| defer synctest.Wait() |
| b.gate.Lock() |
| defer b.unlock() |
| b.err = err |
| } |
| |
| // roundTrip starts a RoundTrip call. |
| // |
| // (Note that the RoundTrip won't complete until response headers are received, |
| // the request times out, or some other terminal condition is reached.) |
| func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { |
| rt := &testRoundTrip{} |
| rt.do(tc.t, req, func(req *http.Request) (*http.Response, error) { |
| if tc.cc1 != nil { |
| return tc.cc1.RoundTrip(req) |
| } |
| return tc.cc.TestRoundTrip(req, func(streamID uint32) { |
| rt.id.Store(streamID) |
| }) |
| }) |
| return rt |
| } |
| |
| func newTestRoundTrip(t testing.TB, req *http.Request, f func(*http.Request) (*http.Response, error)) *testRoundTrip { |
| rt := &testRoundTrip{} |
| rt.do(t, req, f) |
| return rt |
| } |
| |
| func (rt *testRoundTrip) do(t testing.TB, req *http.Request, f func(*http.Request) (*http.Response, error)) { |
| if rt.t != nil { |
| t.Fatal("testRoundTrip can only be used once") |
| } |
| ctx, cancel := context.WithCancel(req.Context()) |
| req = req.WithContext(ctx) |
| rt.t = t |
| rt.donec = make(chan struct{}) |
| rt.cancel = cancel |
| go func() { |
| defer close(rt.donec) |
| rt.resp, rt.respErr = f(req) |
| }() |
| synctest.Wait() |
| |
| t.Cleanup(func() { |
| rt.cancel() |
| if !rt.done() { |
| return |
| } |
| res, _ := rt.result() |
| if res != nil { |
| res.Body.Close() |
| } |
| }) |
| } |
| |
| func (tc *testClientConn) greet(settings ...Setting) { |
| tc.wantFrameType(FrameSettings) |
| tc.wantFrameType(FrameWindowUpdate) |
| tc.writeSettings(settings...) |
| tc.writeSettingsAck() |
| tc.wantFrameType(FrameSettings) // acknowledgement |
| } |
| |
| // makeHeaderBlockFragment encodes headers in a form suitable for inclusion |
| // in a HEADERS or CONTINUATION frame. |
| // |
| // It takes a list of alternating names and values. |
| func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte { |
| if len(s)%2 != 0 { |
| tc.t.Fatalf("uneven list of header name/value pairs") |
| } |
| tc.encbuf.Reset() |
| for i := 0; i < len(s); i += 2 { |
| tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]}) |
| } |
| return tc.encbuf.Bytes() |
| } |
| |
| // inflowWindow returns the amount of inbound flow control available for a stream, |
| // or for the connection if streamID is 0. |
| func (tc *testClientConn) inflowWindow(streamID uint32) int32 { |
| synctest.Wait() |
| w, err := tc.cc.TestInflowWindow(streamID) |
| if err != nil { |
| tc.t.Error(err) |
| } |
| return w |
| } |
| |
| // testRoundTrip manages a RoundTrip in progress. |
| type testRoundTrip struct { |
| t testing.TB |
| resp *http.Response |
| respErr error |
| donec chan struct{} |
| id atomic.Uint32 |
| cancel context.CancelFunc |
| } |
| |
| // streamID returns the HTTP/2 stream ID of the request. |
| func (rt *testRoundTrip) streamID() uint32 { |
| synctest.Wait() |
| id := rt.id.Load() |
| if id == 0 { |
| panic("stream ID unknown") |
| } |
| return id |
| } |
| |
| // done reports whether RoundTrip has returned. |
| func (rt *testRoundTrip) done() bool { |
| synctest.Wait() |
| select { |
| case <-rt.donec: |
| return true |
| default: |
| return false |
| } |
| } |
| |
| // result returns the result of the RoundTrip. |
| func (rt *testRoundTrip) result() (*http.Response, error) { |
| t := rt.t |
| t.Helper() |
| synctest.Wait() |
| select { |
| case <-rt.donec: |
| default: |
| t.Fatalf("RoundTrip is not done; want it to be") |
| } |
| return rt.resp, rt.respErr |
| } |
| |
| // response returns the response of a successful RoundTrip. |
| // If the RoundTrip unexpectedly failed, it calls t.Fatal. |
| func (rt *testRoundTrip) response() *http.Response { |
| t := rt.t |
| t.Helper() |
| resp, err := rt.result() |
| if err != nil { |
| t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr) |
| } |
| if resp == nil { |
| t.Fatalf("RoundTrip returned nil *Response and nil error") |
| } |
| return resp |
| } |
| |
| // err returns the (possibly nil) error result of RoundTrip. |
| func (rt *testRoundTrip) err() error { |
| t := rt.t |
| t.Helper() |
| _, err := rt.result() |
| return err |
| } |
| |
| // wantStatus indicates the expected response StatusCode. |
| func (rt *testRoundTrip) wantStatus(want int) { |
| t := rt.t |
| t.Helper() |
| if got := rt.response().StatusCode; got != want { |
| t.Fatalf("got response status %v, want %v", got, want) |
| } |
| } |
| |
| // readBody reads the contents of the response body. |
| func (rt *testRoundTrip) readBody() ([]byte, error) { |
| t := rt.t |
| t.Helper() |
| return io.ReadAll(rt.response().Body) |
| } |
| |
| // wantBody indicates the expected response body. |
| // (Note that this consumes the body.) |
| func (rt *testRoundTrip) wantBody(want []byte) { |
| t := rt.t |
| t.Helper() |
| got, err := rt.readBody() |
| if err != nil { |
| t.Fatalf("unexpected error reading response body: %v", err) |
| } |
| if !bytes.Equal(got, want) { |
| t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want) |
| } |
| } |
| |
| // wantHeaders indicates the expected response headers. |
| func (rt *testRoundTrip) wantHeaders(want http.Header) { |
| t := rt.t |
| t.Helper() |
| res := rt.response() |
| if diff := diffHeaders(res.Header, want); diff != "" { |
| t.Fatalf("unexpected response headers:\n%v", diff) |
| } |
| } |
| |
| // wantTrailers indicates the expected response trailers. |
| func (rt *testRoundTrip) wantTrailers(want http.Header) { |
| t := rt.t |
| t.Helper() |
| res := rt.response() |
| if diff := diffHeaders(res.Trailer, want); diff != "" { |
| t.Fatalf("unexpected response trailers:\n%v", diff) |
| } |
| } |
| |
| func diffHeaders(got, want http.Header) string { |
| // nil and 0-length non-nil are equal. |
| if len(got) == 0 && len(want) == 0 { |
| return "" |
| } |
| // We could do a more sophisticated diff here. |
| // DeepEqual is good enough for now. |
| if reflect.DeepEqual(got, want) { |
| return "" |
| } |
| return fmt.Sprintf("got: %v\nwant: %v", got, want) |
| } |
| |
| // roundTripTestMode selects which RoundTrip API a test uses. |
| type roundTripTestMode int |
| |
| const ( |
| // roundTripNetHTTP uses net/http.Transport.RoundTrip or |
| // net/http.ClientConn.RoundTrip: |
| // |
| // t1 := http.Transport{} |
| // t2 := ConfigureTransports(t1) |
| // resp, err := t1.RoundTrip(req) |
| // |
| roundTripNetHTTP = roundTripTestMode(iota) |
| |
| // roundTripXNetHTTP2 uses x/net/http2.Transport.RoundTrip or |
| // x/net/http2.ClientConn.RoundTrip: |
| // |
| // t2 := http2.Transport{} |
| // resp, err := t2.RoundTrip(req) |
| roundTripXNetHTTP2 |
| ) |
| |
| // A testTransport allows testing Transport.RoundTrip against fake servers. |
| // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling |
| // should use testClientConn instead. |
| type testTransport struct { |
| t testing.TB |
| tr *Transport |
| tr1 *http.Transport |
| li *synctestNetListener |
| mode roundTripTestMode |
| |
| ccMu sync.Mutex |
| ccqueue []*testClientConn |
| ccs map[*synctestNetConn]*testClientConn |
| |
| ccpending []*testPendingClientConn |
| |
| useTLS bool |
| } |
| |
| type testPendingClientConn struct { |
| nc *synctestNetConn |
| cc *ClientConn |
| tc *testClientConn |
| } |
| |
| func newTestTransport(t testing.TB, opts ...any) *testTransport { |
| tt := &testTransport{ |
| t: t, |
| li: newSynctestNetListener(), |
| ccs: make(map[*synctestNetConn]*testClientConn), |
| mode: roundTripXNetHTTP2, |
| } |
| |
| for _, o := range opts { |
| switch o := o.(type) { |
| case roundTripTestMode: |
| tt.mode = o |
| } |
| } |
| |
| var ( |
| tr *Transport |
| tr1 *http.Transport |
| ) |
| switch tt.mode { |
| case roundTripXNetHTTP2: |
| tr = &Transport{ |
| DialTLSContext: func(ctx context.Context, network, address string, tlsConf *tls.Config) (net.Conn, error) { |
| return tt.li.newConn(), nil |
| }, |
| AllowHTTP: true, |
| } |
| case roundTripNetHTTP: |
| tr1 = &http.Transport{ |
| DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { |
| return tt.li.newConn(), nil |
| }, |
| Protocols: &http.Protocols{}, |
| TLSClientConfig: testTLSClientConfig, |
| } |
| tr1.Protocols.SetHTTP2(true) |
| tr1.Protocols.SetUnencryptedHTTP2(true) |
| t.Cleanup(tr1.CloseIdleConnections) |
| var err error |
| tr, err = ConfigureTransports(tr1) |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| for _, o := range opts { |
| switch o := o.(type) { |
| case func(*http.Transport): |
| o(tr.TestTransport()) |
| case func(*Transport): |
| o(tr) |
| case *Transport: |
| tr = o |
| case roundTripTestMode: |
| tt.mode = o |
| case nil: |
| default: |
| t.Fatalf("unsupported option %T", o) |
| } |
| } |
| tt.tr = tr |
| tt.tr1 = tr.TestTransport() |
| |
| go tt.accept() |
| |
| tt.tr.TestSetNewClientConnHook(func(cc *http2.ClientConn) { |
| nc, ok := cc.TestNetConn().(*synctestNetConn) |
| if !ok { |
| return |
| } |
| tt.addPending(nc.peer, cc, nil) |
| }) |
| |
| t.Cleanup(func() { |
| tt.li.Close() |
| synctest.Wait() |
| if len(tt.ccqueue) > 0 { |
| t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccqueue)) |
| } |
| }) |
| |
| return tt |
| } |
| |
| func (tt *testTransport) addPending(nc *synctestNetConn, cc *ClientConn, tc *testClientConn) { |
| tt.ccMu.Lock() |
| defer tt.ccMu.Unlock() |
| |
| for i, p := range tt.ccpending { |
| if p.nc != nc { |
| break |
| } |
| if p.tc != nil { |
| p.tc.cc = cc |
| } else if tc != nil { |
| tc.cc = p.cc |
| } else { |
| panic("found matching ccpending for conn with no tc") |
| } |
| tt.ccpending = slices.Delete(tt.ccpending, i, i+1) |
| return |
| } |
| |
| tt.ccpending = append(tt.ccpending, &testPendingClientConn{ |
| nc: nc, |
| cc: cc, |
| tc: tc, |
| }) |
| } |
| |
| func (tt *testTransport) accept() { |
| for { |
| nc, err := tt.li.Accept() |
| if err != nil { |
| return |
| } |
| tc := newTestClientConnFromNetConn(tt, nc) |
| tt.addPending(nc.(*synctestNetConn), nil, tc) |
| tt.ccqueue = append(tt.ccqueue, tc) |
| } |
| } |
| |
| func (tt *testTransport) hasConn() bool { |
| return len(tt.ccqueue) > 0 |
| } |
| |
| func (tt *testTransport) getConn() *testClientConn { |
| tt.t.Helper() |
| synctest.Wait() |
| tt.ccMu.Lock() |
| if len(tt.ccqueue) == 0 { |
| tt.ccMu.Unlock() |
| tt.t.Fatalf("no new ClientConns created; wanted one") |
| } |
| tc := tt.ccqueue[0] |
| tt.ccqueue = tt.ccqueue[1:] |
| tt.ccMu.Unlock() |
| tc.readClientPreface() |
| return tc |
| } |
| |
| func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { |
| ctx, cancel := context.WithCancel(req.Context()) |
| req = req.WithContext(ctx) |
| rt := &testRoundTrip{ |
| t: tt.t, |
| donec: make(chan struct{}), |
| cancel: cancel, |
| } |
| |
| go func() { |
| defer close(rt.donec) |
| switch tt.mode { |
| case roundTripXNetHTTP2: |
| rt.resp, rt.respErr = tt.tr.RoundTrip(req) |
| case roundTripNetHTTP: |
| rt.resp, rt.respErr = tt.tr1.RoundTrip(req) |
| } |
| }() |
| synctest.Wait() |
| |
| tt.t.Cleanup(func() { |
| rt.cancel() |
| if !rt.done() { |
| return |
| } |
| res, _ := rt.result() |
| if res != nil { |
| res.Body.Close() |
| } |
| }) |
| |
| return rt |
| } |
| |
| type nonblockingReader struct { |
| mu sync.Mutex |
| buf bytes.Buffer |
| err error |
| waitc chan struct{} |
| stopc chan struct{} |
| } |
| |
| func newNonblockingReader(reader io.Reader) *nonblockingReader { |
| r := &nonblockingReader{} |
| go func() { |
| buf := make([]byte, 1024) |
| for { |
| n, err := reader.Read(buf) |
| r.mu.Lock() |
| if n > 0 { |
| r.buf.Write(buf[:n]) |
| } |
| if err != nil { |
| r.err = err |
| } |
| if r.waitc != nil { |
| close(r.waitc) |
| r.waitc = nil |
| } |
| stopc := r.stopc |
| r.mu.Unlock() |
| if err != nil { |
| return |
| } |
| if stopc != nil { |
| <-stopc |
| } |
| } |
| }() |
| return r |
| } |
| |
| func (r *nonblockingReader) Read(p []byte) (n int, err error) { |
| synctest.Wait() |
| r.mu.Lock() |
| defer r.mu.Unlock() |
| n, err = r.buf.Read(p) |
| if err == io.EOF { |
| if r.err != nil { |
| err = r.err |
| } else { |
| err = errWouldBlock |
| } |
| } |
| return n, err |
| } |
| |
| func (r *nonblockingReader) waitForData(t testing.TB) time.Duration { |
| t.Helper() |
| synctest.Wait() |
| waitc := func() chan struct{} { |
| r.mu.Lock() |
| defer r.mu.Unlock() |
| if r.buf.Len() > 0 || r.err != nil { |
| return nil |
| } |
| if r.waitc == nil { |
| r.waitc = make(chan struct{}) |
| } |
| return r.waitc |
| }() |
| if waitc == nil { |
| return 0 |
| } |
| start := time.Now() |
| select { |
| case <-waitc: |
| case <-time.After(1 * time.Hour): |
| t.Fatalf("waited an hour for connection data, saw none") |
| } |
| return time.Since(start) |
| } |
| |
| func (r *nonblockingReader) stop() { |
| synctest.Wait() |
| if r.stopc != nil { |
| panic("stopping stopped reader") |
| } |
| r.stopc = make(chan struct{}) |
| } |
| |
| func (r *nonblockingReader) start() { |
| synctest.Wait() |
| if r.stopc == nil { |
| panic("starting started reader") |
| } |
| stopc := r.stopc |
| r.stopc = nil |
| close(stopc) |
| } |
| |
| var errWouldBlock = errors.New("would block") |