http2: convert the remaining clientTester tests to testClientConn
Change-Id: Ia7f213346baff48504fef6dfdc112575a5459f35
Reviewed-on: https://go-review.googlesource.com/c/net/+/572380
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go
index 73ceefd..4237b14 100644
--- a/http2/clientconn_test.go
+++ b/http2/clientconn_test.go
@@ -14,6 +14,7 @@
"net"
"net/http"
"reflect"
+ "slices"
"testing"
"time"
@@ -209,6 +210,71 @@
}
}
+// wantUnorderedFrames reads frames from the conn until every condition in want has been satisfied.
+//
+// want is a list of func(*SomeFrame) bool.
+// wantUnorderedFrames will call each func with frames of the appropriate type
+// until the func returns true.
+// It calls t.Fatal if an unexpected frame is received (no func has that frame type,
+// or all funcs with that type have returned true), or if the conn runs out of frames
+// with unsatisfied funcs.
+//
+// Example:
+//
+// // Read a SETTINGS frame, and any number of DATA frames for a stream.
+// // The SETTINGS frame may appear anywhere in the sequence.
+// // The last DATA frame must indicate the end of the stream.
+// tc.wantUnorderedFrames(
+// func(f *SettingsFrame) bool {
+// return true
+// },
+// func(f *DataFrame) bool {
+// return f.StreamEnded()
+// },
+// )
+func (tc *testClientConn) wantUnorderedFrames(want ...any) {
+ tc.t.Helper()
+ want = slices.Clone(want)
+ seen := 0
+frame:
+ for seen < len(want) && !tc.t.Failed() {
+ fr := tc.readFrame()
+ if fr == nil {
+ break
+ }
+ for i, f := range want {
+ if f == nil {
+ continue
+ }
+ typ := reflect.TypeOf(f)
+ if typ.Kind() != reflect.Func ||
+ typ.NumIn() != 1 ||
+ typ.NumOut() != 1 ||
+ typ.Out(0) != reflect.TypeOf(true) {
+ tc.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
+ }
+ if typ.In(0) == reflect.TypeOf(fr) {
+ out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
+ if out[0].Bool() {
+ want[i] = nil
+ seen++
+ }
+ continue frame
+ }
+ }
+ tc.t.Errorf("got unexpected frame type %T", fr)
+ }
+ if seen < len(want) {
+ for _, f := range want {
+ if f == nil {
+ continue
+ }
+ tc.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
+ }
+ tc.t.Fatalf("did not see %v expected frame types", len(want)-seen)
+ }
+}
+
type wantHeader struct {
streamID uint32
endStream bool
@@ -401,6 +467,14 @@
tc.sync()
}
+func (tc *testClientConn) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
+ tc.t.Helper()
+ if err := tc.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
+ tc.t.Fatal(err)
+ }
+ tc.sync()
+}
+
// makeHeaderBlockFragment encodes headers in a form suitable for inclusion
// in a HEADERS or CONTINUATION frame.
//
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 18d4db3..855c107 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2724,122 +2724,75 @@
}
func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
- ct := newClientTester(t)
+ tc := newTestClientConn(t)
+ tc.greet()
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
- if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
- return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
- }
- res.Body.Close() // leaving 4999 bytes unread
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "5000",
+ ),
+ })
+ initialInflow := tc.inflowWindow(0)
- return nil
+ // Two cases:
+ // - Send one DATA frame with 5000 bytes.
+ // - Send two DATA frames with 1 and 4999 bytes each.
+ //
+ // In both cases, the client should consume one byte of data,
+ // refund that byte, then refund the following 4999 bytes.
+ //
+ // In the second case, the server waits for the client to reset the
+ // stream before sending the second DATA frame. This tests the case
+ // where the client receives a DATA frame after it has reset the stream.
+ const streamNotEnded = false
+ if oneDataFrame {
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000))
+ } else {
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1))
}
- ct.server = func() error {
- ct.greet()
- var hf *HeadersFrame
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- var ok bool
- hf, ok = f.(*HeadersFrame)
- if !ok {
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- break
- }
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- initialInflow := ct.inflowWindow(0)
-
- // Two cases:
- // - Send one DATA frame with 5000 bytes.
- // - Send two DATA frames with 1 and 4999 bytes each.
- //
- // In both cases, the client should consume one byte of data,
- // refund that byte, then refund the following 4999 bytes.
- //
- // In the second case, the server waits for the client to reset the
- // stream before sending the second DATA frame. This tests the case
- // where the client receives a DATA frame after it has reset the stream.
- if oneDataFrame {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
- } else {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
- }
-
- wantRST := true
- wantWUF := true
- if !oneDataFrame {
- wantWUF = false // flow control update is small, and will not be sent
- }
- for wantRST || wantWUF {
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *RSTStreamFrame:
- if !wantRST {
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
- }
- if f.ErrCode != ErrCodeCancel {
- return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
- }
- wantRST = false
- case *WindowUpdateFrame:
- if !wantWUF {
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
- }
- if f.Increment != 5000 {
- return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
- }
- wantWUF = false
- default:
- return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
- }
- }
- if !oneDataFrame {
- ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
- f, err := ct.readNonSettingsFrame()
- if err != nil {
- return err
- }
- wuf, ok := f.(*WindowUpdateFrame)
- if !ok || wuf.Increment != 5000 {
- return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f))
- }
- }
- if err := ct.writeReadPing(); err != nil {
- return err
- }
- if got, want := ct.inflowWindow(0), initialInflow; got != want {
- return fmt.Errorf("connection flow tokens = %v, want %v", got, want)
- }
- return nil
+ res := rt.response()
+ if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
+ t.Fatalf("body read = %v, %v; want 1, nil", n, err)
}
- ct.run()
+ res.Body.Close() // leaving 4999 bytes unread
+ tc.sync()
+
+ sentAdditionalData := false
+ tc.wantUnorderedFrames(
+ func(f *RSTStreamFrame) bool {
+ if f.ErrCode != ErrCodeCancel {
+ t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
+ }
+ if !oneDataFrame {
+ // Send the remaining data now.
+ tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999))
+ sentAdditionalData = true
+ }
+ return true
+ },
+ func(f *WindowUpdateFrame) bool {
+ if !oneDataFrame && !sentAdditionalData {
+ t.Fatalf("Got WindowUpdateFrame, don't expect one yet")
+ }
+ if f.Increment != 5000 {
+ t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
+ }
+ return true
+ },
+ )
+
+ if got, want := tc.inflowWindow(0), initialInflow; got != want {
+ t.Fatalf("connection flow tokens = %v, want %v", got, want)
+ }
}
// See golang.org/issue/16481
@@ -2855,199 +2808,124 @@
// Issue 16612: adjust flow control on open streams when transport
// receives SETTINGS with INITIAL_WINDOW_SIZE from server.
func TestTransportAdjustsFlowControl(t *testing.T) {
- ct := newClientTester(t)
- clientDone := make(chan struct{})
-
const bodySize = 1 << 20
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
+ tc := newTestClientConn(t)
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ // Don't write our SETTINGS yet.
- req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- res.Body.Close()
- return nil
- }
- ct.server = func() error {
- _, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
+ body := tc.newRequestBody()
+ body.writeBytes(bodySize)
+ body.closeWithError(io.EOF)
- var gotBytes int64
- var sentSettings bool
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- return nil
- default:
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- }
- switch f := f.(type) {
- case *DataFrame:
- gotBytes += int64(len(f.Data()))
- // After we've got half the client's
- // initial flow control window's worth
- // of request body data, give it just
- // enough flow control to finish.
- if gotBytes >= initialWindowSize/2 && !sentSettings {
- sentSettings = true
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", body)
+ rt := tc.roundTrip(req)
- ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
- ct.fr.WriteWindowUpdate(0, bodySize)
- ct.fr.WriteSettingsAck()
- }
+ tc.wantFrameType(FrameHeaders)
- if f.StreamEnded() {
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- }
+ gotBytes := int64(0)
+ for {
+ f := testClientConnReadFrame[*DataFrame](tc)
+ gotBytes += int64(len(f.Data()))
+ // After we've got half the client's initial flow control window's worth
+ // of request body data, give it just enough flow control to finish.
+ if gotBytes >= initialWindowSize/2 {
+ break
}
}
- ct.run()
+
+ tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
+ tc.writeWindowUpdate(0, bodySize)
+ tc.writeSettingsAck()
+
+ tc.wantUnorderedFrames(
+ func(f *SettingsFrame) bool { return true },
+ func(f *DataFrame) bool {
+ gotBytes += int64(len(f.Data()))
+ return f.StreamEnded()
+ },
+ )
+
+ if gotBytes != bodySize {
+ t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize)
+ }
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
}
// See golang.org/issue/16556
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
- ct := newClientTester(t)
+ tc := newTestClientConn(t)
+ tc.greet()
- unblockClient := make(chan bool, 1)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- defer res.Body.Close()
- <-unblockClient
- return nil
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ "content-length", "5000",
+ ),
+ })
+
+ initialConnWindow := tc.inflowWindow(0)
+ initialStreamWindow := tc.inflowWindow(rt.streamID())
+
+ pad := make([]byte, 5)
+ tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad)
+
+ // Padding flow control should have been returned.
+ if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want {
+ t.Errorf("conn inflow window = %v, want %v", got, want)
}
- ct.server = func() error {
- ct.greet()
-
- var hf *HeadersFrame
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- }
- var ok bool
- hf, ok = f.(*HeadersFrame)
- if !ok {
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- }
- break
- }
-
- initialConnWindow := ct.inflowWindow(0)
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- initialStreamWindow := ct.inflowWindow(hf.StreamID)
- pad := make([]byte, 5)
- ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream
- if err := ct.writeReadPing(); err != nil {
- return err
- }
- // Padding flow control should have been returned.
- if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want {
- t.Errorf("conn inflow window = %v, want %v", got, want)
- }
- if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want {
- t.Errorf("stream inflow window = %v, want %v", got, want)
- }
- unblockClient <- true
- return nil
+ if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want {
+ t.Errorf("stream inflow window = %v, want %v", got, want)
}
- ct.run()
}
// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
// StreamError as a result of the response HEADERS
func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
- ct := newClientTester(t)
+ tc := newTestClientConn(t)
+ tc.greet()
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err == nil {
- res.Body.Close()
- return errors.New("unexpected successful GET")
- }
- want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
- if !reflect.DeepEqual(want, err) {
- t.Errorf("RoundTrip error = %#v; want %#v", err, want)
- }
- return nil
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ " content-type", "bogus",
+ ),
+ })
+
+ err := rt.err()
+ want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
+ if !reflect.DeepEqual(err, want) {
+ t.Fatalf("RoundTrip error = %#v; want %#v", err, want)
}
- ct.server = func() error {
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- return err
- }
-
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
-
- for {
- fr, err := ct.readFrame()
- if err != nil {
- return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
- }
- if _, ok := fr.(*SettingsFrame); ok {
- continue
- }
- if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
- t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
- }
- break
- }
-
- return nil
+ fr := testClientConnReadFrame[*RSTStreamFrame](tc)
+ if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol {
+ t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
}
- ct.run()
}
// byteAndEOFReader returns is in an io.Reader which reads one byte
@@ -3461,261 +3339,84 @@
}
}
-func testTransportPingWhenReading(t *testing.T, readIdleTimeout, deadline time.Duration, expectedPingCount int) {
- var pingCount int
- ct := newClientTester(t)
- ct.tr.ReadIdleTimeout = readIdleTimeout
-
- ctx, cancel := context.WithTimeout(context.Background(), deadline)
- defer cancel()
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip: %v", err)
- }
- defer res.Body.Close()
- if res.StatusCode != 200 {
- return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
- }
- _, err = ioutil.ReadAll(res.Body)
- if expectedPingCount == 0 && errors.Is(ctx.Err(), context.DeadlineExceeded) {
- return nil
- }
-
- cancel()
- return err
- }
-
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var streamID uint32
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-ctx.Done():
- // If the client's done, it
- // will have reported any
- // errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- streamID = f.StreamID
- case *PingFrame:
- pingCount++
- if pingCount == expectedPingCount {
- if err := ct.fr.WriteData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
- return err
- }
- }
- if err := ct.fr.WritePing(true, f.Data); err != nil {
- return err
- }
- case *RSTStreamFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
- }
- ct.run()
-}
-
-func testClientMultipleDials(t *testing.T, client func(*Transport), server func(int, *clientTester)) {
- ln := newLocalListener(t)
- defer ln.Close()
-
- var (
- mu sync.Mutex
- count int
- conns []net.Conn
- )
- var wg sync.WaitGroup
- tr := &Transport{
- TLSClientConfig: tlsConfigInsecure,
- }
- tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
- mu.Lock()
- defer mu.Unlock()
- count++
- cc, err := net.Dial("tcp", ln.Addr().String())
- if err != nil {
- return nil, fmt.Errorf("dial error: %v", err)
- }
- conns = append(conns, cc)
- sc, err := ln.Accept()
- if err != nil {
- return nil, fmt.Errorf("accept error: %v", err)
- }
- conns = append(conns, sc)
- ct := &clientTester{
- t: t,
- tr: tr,
- cc: cc,
- sc: sc,
- fr: NewFramer(sc, sc),
- }
- wg.Add(1)
- go func(count int) {
- defer wg.Done()
- server(count, ct)
- }(count)
- return cc, nil
- }
-
- client(tr)
- tr.CloseIdleConnections()
- ln.Close()
- for _, c := range conns {
- c.Close()
- }
- wg.Wait()
-}
-
func TestTransportRetryAfterGOAWAY(t *testing.T) {
- client := func(tr *Transport) {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := tr.RoundTrip(req)
- if res != nil {
- res.Body.Close()
- if got := res.Header.Get("Foo"); got != "bar" {
- err = fmt.Errorf("foo header = %q; want bar", got)
- }
- }
- if err != nil {
- t.Errorf("RoundTrip: %v", err)
- }
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a GOAWAY.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeGoAway(0 /*max id*/, ErrCodeNo, nil)
+ if rt.done() {
+ t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
}
- server := func(count int, ct *clientTester) {
- switch count {
- case 1:
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- t.Errorf("server1 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server1 got %v", hf)
- if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
- t.Errorf("server1 failed writing GOAWAY: %v", err)
- return
- }
- case 2:
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- t.Errorf("server2 failed reading HEADERS: %v", err)
- return
- }
- t.Logf("server2 got %v", hf)
+ // Second attempt succeeds on a new connection.
+ tc = tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
- err = ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- if err != nil {
- t.Errorf("server2 failed writing response HEADERS: %v", err)
- }
- default:
- t.Errorf("unexpected number of dials")
- return
- }
- }
-
- testClientMultipleDials(t, client, server)
+ rt.wantStatus(200)
}
func TestTransportRetryAfterRefusedStream(t *testing.T) {
- clientDone := make(chan struct{})
- client := func(tr *Transport) {
- defer close(clientDone)
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- resp, err := tr.RoundTrip(req)
- if err != nil {
- t.Errorf("RoundTrip: %v", err)
- return
- }
- resp.Body.Close()
- if resp.StatusCode != 204 {
- t.Errorf("Status = %v; want 204", resp.StatusCode)
- return
- }
+ tt := newTestTransport(t)
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tt.roundTrip(req)
+
+ // First attempt: Server sends a RST_STREAM.
+ tc := tt.getConn()
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.wantFrameType(FrameSettings) // settings ACK
+ tc.writeRSTStream(1, ErrCodeRefusedStream)
+ if rt.done() {
+ t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying")
}
- server := func(_ int, ct *clientTester) {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- var count int
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- default:
- t.Error(err)
- }
- return
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- if !f.HeadersEnded() {
- t.Errorf("headers should have END_HEADERS be ended: %v", f)
- return
- }
- count++
- if count == 1 {
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- } else {
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- default:
- t.Errorf("Unexpected client frame %v", f)
- return
- }
- }
- }
+ // Second attempt succeeds on the same connection.
+ tc.wantHeaders(wantHeader{
+ streamID: 3,
+ endStream: true,
+ })
+ tc.writeSettings()
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 3,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "204",
+ ),
+ })
- testClientMultipleDials(t, client, server)
+ rt.wantStatus(204)
}
func TestTransportRetryHasLimit(t *testing.T) {
@@ -3765,67 +3466,34 @@
}
func TestTransportResponseDataBeforeHeaders(t *testing.T) {
- // This test use not valid response format.
- // Discarding logger output to not spam tests output.
- log.SetOutput(ioutil.Discard)
- defer log.SetOutput(os.Stderr)
+ // Discard log output complaining about protocol error.
+ log.SetOutput(io.Discard)
+ t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
- // First request is normal to ensure the check is per stream and not per connection.
- _, err := ct.tr.RoundTrip(req)
- if err != nil {
- return fmt.Errorf("RoundTrip expected no error, got: %v", err)
- }
- // Second request returns a DATA frame with no HEADERS.
- resp, err := ct.tr.RoundTrip(req)
- if err == nil {
- return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
- }
- if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
- return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
- }
- return nil
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ // First request is normal to ensure the check is per stream and not per connection.
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt1 := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt1.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
+
+ // Second request returns a DATA frame with no HEADERS.
+ rt2 := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ tc.writeData(rt2.streamID(), true, []byte("payload"))
+ if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol {
+ t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err)
}
- ct.server = func() error {
- ct.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err == io.EOF {
- return nil
- } else if err != nil {
- return err
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame, *RSTStreamFrame:
- case *HeadersFrame:
- switch f.StreamID {
- case 1:
- // Send a valid response to first request.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- case 3:
- ct.fr.WriteData(f.StreamID, true, []byte("payload"))
- }
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
- }
- ct.run()
}
func TestTransportMaxFrameReadSize(t *testing.T) {
@@ -3839,30 +3507,17 @@
maxReadFrameSize: 1024,
want: minMaxFrameSize,
}} {
- ct := newClientTester(t)
- ct.tr.MaxReadFrameSize = test.maxReadFrameSize
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
- ct.tr.RoundTrip(req)
- return nil
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxReadFrameSize = test.maxReadFrameSize
+ })
+
+ fr := testClientConnReadFrame[*SettingsFrame](tc)
+ got, ok := fr.Value(SettingMaxFrameSize)
+ if !ok {
+ t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want)
+ } else if got != test.want {
+ t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
}
- ct.server = func() error {
- defer ct.cc.(*net.TCPConn).Close()
- ct.greet()
- var got uint32
- ct.settings.ForeachSetting(func(s Setting) error {
- switch s.ID {
- case SettingMaxFrameSize:
- got = s.Val
- }
- return nil
- })
- if got != test.want {
- t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
- }
- return nil
- }
- ct.run()
}
}
@@ -3915,324 +3570,113 @@
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
const maxConcurrent = 2
- greet := make(chan struct{}) // server sends initial SETTINGS frame
- gotRequest := make(chan struct{}) // server received a request
- clientDone := make(chan struct{})
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.StrictMaxConcurrentStreams = true
+ })
+ tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
cancelClientRequest := make(chan struct{})
- // Collect errors from goroutines.
- var wg sync.WaitGroup
- errs := make(chan error, 100)
- defer func() {
- wg.Wait()
- close(errs)
- for err := range errs {
- t.Error(err)
+ // Start maxConcurrent+2 requests.
+ // The server does not respond to any of them yet.
+ var rts []*testRoundTrip
+ for k := 0; k < maxConcurrent+2; k++ {
+ req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+ if k == maxConcurrent {
+ req.Cancel = cancelClientRequest
}
- }()
+ rt := tc.roundTrip(req)
+ rts = append(rts, rt)
- // We will send maxConcurrent+2 requests. This checker goroutine waits for the
- // following stages:
- // 1. The first maxConcurrent requests are received by the server.
- // 2. The client will cancel the next request
- // 3. The server is unblocked so it can service the first maxConcurrent requests
- // 4. The client will send the final request
- wg.Add(1)
- unblockClient := make(chan struct{})
- clientRequestCancelled := make(chan struct{})
- unblockServer := make(chan struct{})
- go func() {
- defer wg.Done()
- // Stage 1.
- for k := 0; k < maxConcurrent; k++ {
- <-gotRequest
+ if k < maxConcurrent {
+ // We are under the stream limit, so the client sends the request.
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{fmt.Sprintf("/%d", k)},
+ },
+ })
+ } else {
+ // We have reached the stream limit,
+ // so the client cannot send the request.
+ if fr := tc.readFrame(); fr != nil {
+ t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr)
+ }
}
- // Stage 2.
- close(unblockClient)
- <-clientRequestCancelled
- // Stage 3: give some time for the final RoundTrip call to be scheduled and
- // verify that the final request is not sent.
- time.Sleep(50 * time.Millisecond)
- select {
- case <-gotRequest:
- errs <- errors.New("last request did not stall")
- close(unblockServer)
- return
- default:
- }
- close(unblockServer)
- // Stage 4.
- <-gotRequest
- }()
- ct := newClientTester(t)
- ct.tr.StrictMaxConcurrentStreams = true
- ct.client = func() error {
- var wg sync.WaitGroup
- defer func() {
- wg.Wait()
- close(clientDone)
- ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- ct.cc.(*net.TCPConn).Close()
- }
- }()
- for k := 0; k < maxConcurrent+2; k++ {
- wg.Add(1)
- go func(k int) {
- defer wg.Done()
- // Don't send the second request until after receiving SETTINGS from the server
- // to avoid a race where we use the default SettingMaxConcurrentStreams, which
- // is much larger than maxConcurrent. We have to send the first request before
- // waiting because the first request triggers the dial and greet.
- if k > 0 {
- <-greet
- }
- // Block until maxConcurrent requests are sent before sending any more.
- if k >= maxConcurrent {
- <-unblockClient
- }
- body := newStaticCloseChecker("")
- req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
- if k == maxConcurrent {
- // This request will be canceled.
- req.Cancel = cancelClientRequest
- close(cancelClientRequest)
- _, err := ct.tr.RoundTrip(req)
- close(clientRequestCancelled)
- if err == nil {
- errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
- return
- }
- } else {
- resp, err := ct.tr.RoundTrip(req)
- if err != nil {
- errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
- return
- }
- ioutil.ReadAll(resp.Body)
- resp.Body.Close()
- if resp.StatusCode != 204 {
- errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
- return
- }
- }
- if err := body.isClosed(); err != nil {
- errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
- }
- }(k)
- }
- return nil
- }
-
- ct.server = func() error {
- var wg sync.WaitGroup
- defer wg.Wait()
-
- ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
-
- // Server write loop.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- writeResp := make(chan uint32, maxConcurrent+1)
-
- wg.Add(1)
- go func() {
- defer wg.Done()
- <-unblockServer
- for id := range writeResp {
- buf.Reset()
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: id,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- }()
-
- // Server read loop.
- var nreq int
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it will have reported any errors on its side.
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame:
- case *SettingsFrame:
- // Wait for the client SETTINGS ack until ending the greet.
- close(greet)
- case *HeadersFrame:
- if !f.HeadersEnded() {
- return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
- }
- gotRequest <- struct{}{}
- nreq++
- writeResp <- f.StreamID
- if nreq == maxConcurrent+1 {
- close(writeResp)
- }
- case *DataFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
+ if rt.done() {
+ t.Fatalf("rt %v done", k)
}
}
- ct.run()
+ // Cancel the maxConcurrent'th request.
+ // The request should fail.
+ close(cancelClientRequest)
+ tc.sync()
+ if err := rts[maxConcurrent].err(); err == nil {
+ t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent)
+ }
+
+ // No requests should be complete, except for the canceled one.
+ for i, rt := range rts {
+ if i != maxConcurrent && rt.done() {
+ t.Fatalf("RoundTrip(%d) is done, but should not be", i)
+ }
+ }
+
+ // Server responds to a request, unblocking the last one.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rts[0].streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ tc.wantHeaders(wantHeader{
+ streamID: rts[maxConcurrent+1].streamID(),
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)},
+ },
+ })
+ rts[0].wantStatus(200)
}
func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
- ct := newClientTester(t)
var reqSize, resSize uint32 = 8192, 16384
- ct.tr.MaxDecoderHeaderTableSize = reqSize
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- cc, err := ct.tr.NewClientConn(ct.cc)
- if err != nil {
- return err
- }
- _, err = cc.RoundTrip(req)
- if err != nil {
- return err
- }
- if got, want := cc.peerMaxHeaderTableSize, resSize; got != want {
- return fmt.Errorf("peerHeaderTableSize = %d, want %d", got, want)
- }
- return nil
- }
- ct.server = func() error {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- ct.t.Fatalf("wanted client settings frame; got %v", f)
- _ = sf // stash it away?
- }
- var found bool
- err = sf.ForeachSetting(func(s Setting) error {
- if s.ID == SettingHeaderTableSize {
- found = true
- if got, want := s.Val, reqSize; got != want {
- return fmt.Errorf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", got, want)
- }
- }
- return nil
- })
- if err != nil {
- return err
- }
- if !found {
- return fmt.Errorf("missing SETTINGS_HEADER_TABLE_SIZE setting")
- }
- if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, resSize}); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxDecoderHeaderTableSize = reqSize
+ })
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ fr := testClientConnReadFrame[*SettingsFrame](tc)
+ if v, ok := fr.Value(SettingHeaderTableSize); !ok {
+ t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting")
+ } else if v != reqSize {
+ t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize)
}
- ct.run()
+
+ tc.writeSettings(Setting{SettingHeaderTableSize, resSize})
+ if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want {
+ t.Fatalf("peerHeaderTableSize = %d, want %d", got, want)
+ }
}
func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
- ct := newClientTester(t)
var peerAdvertisedMaxHeaderTableSize uint32 = 16384
- ct.tr.MaxEncoderHeaderTableSize = 8192
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- cc, err := ct.tr.NewClientConn(ct.cc)
- if err != nil {
- return err
- }
- _, err = cc.RoundTrip(req)
- if err != nil {
- return err
- }
- if got, want := cc.henc.MaxDynamicTableSize(), ct.tr.MaxEncoderHeaderTableSize; got != want {
- return fmt.Errorf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
- }
- return nil
- }
- ct.server = func() error {
- buf := make([]byte, len(ClientPreface))
- _, err := io.ReadFull(ct.sc, buf)
- if err != nil {
- return fmt.Errorf("reading client preface: %v", err)
- }
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- sf, ok := f.(*SettingsFrame)
- if !ok {
- ct.t.Fatalf("wanted client settings frame; got %v", f)
- _ = sf // stash it away?
- }
- if err := ct.fr.WriteSettings(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize}); err != nil {
- ct.t.Fatal(err)
- }
- if err := ct.fr.WriteSettingsAck(); err != nil {
- ct.t.Fatal(err)
- }
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.MaxEncoderHeaderTableSize = 8192
+ })
+ tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize})
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- return nil
- }
- }
+ if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want {
+ t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
}
- ct.run()
}
func TestAuthorityAddr(t *testing.T) {
@@ -4316,40 +3760,24 @@
// Issue 18891: make sure Request.Body == NoBody means no DATA frame
// is ever sent, even if empty.
func TestTransportNoBodyMeansNoDATA(t *testing.T) {
- ct := newClientTester(t)
+ tc := newTestClientConn(t)
+ tc.greet()
- unblockClient := make(chan bool)
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
+ rt := tc.roundTrip(req)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
- ct.tr.RoundTrip(req)
- <-unblockClient
- return nil
+ tc.wantHeaders(wantHeader{
+ streamID: rt.streamID(),
+ endStream: true, // END_STREAM should be set when body is http.NoBody
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{"/"},
+ },
+ })
+ if fr := tc.readFrame(); fr != nil {
+ t.Fatalf("unexpected frame after headers: %v", fr)
}
- ct.server = func() error {
- defer close(unblockClient)
- defer ct.cc.(*net.TCPConn).Close()
- ct.greet()
-
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
- }
- switch f := f.(type) {
- default:
- return fmt.Errorf("Got %T; want HeadersFrame", f)
- case *WindowUpdateFrame, *SettingsFrame:
- continue
- case *HeadersFrame:
- if !f.StreamEnded() {
- return fmt.Errorf("got headers frame without END_STREAM")
- }
- return nil
- }
- }
- }
- ct.run()
}
func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
@@ -4428,41 +3856,22 @@
// Verify transport doesn't crash when receiving bogus response lacking a :status header.
// Issue 22880.
func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
- ct := newClientTester(t)
- ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- _, err := ct.tr.RoundTrip(req)
- const substr = "malformed response from server: missing status pseudo header"
- if !strings.Contains(fmt.Sprint(err), substr) {
- return fmt.Errorf("RoundTrip error = %v; want substring %q", err, substr)
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
+ tc := newTestClientConn(t)
+ tc.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
- switch f := f.(type) {
- case *HeadersFrame:
- enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "text/html"}) // no :status header
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: false, // we'll send some DATA to try to crash the transport
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(f.StreamID, true, []byte("payload"))
- return nil
- }
- }
- }
- ct.run()
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false, // we'll send some DATA to try to crash the transport
+ BlockFragment: tc.makeHeaderBlockFragment(
+ "content-type", "text/html", // no :status header
+ ),
+ })
+ tc.writeData(rt.streamID(), true, []byte("payload"))
}
func BenchmarkClientRequestHeaders(b *testing.B) {
@@ -4810,95 +4219,42 @@
}
func testTransportBodyReadError(t *testing.T, body []byte) {
- if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
- // So far we've only seen this be flaky on Windows and Plan 9,
- // perhaps due to TCP behavior on shutdowns while
- // unread data is in flight. This test should be
- // fixed, but a skip is better than annoying people
- // for now.
- t.Skipf("skipping flaky test on %s; https://golang.org/issue/31260", runtime.GOOS)
- }
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- defer close(clientDone)
+ tc := newTestClientConn(t)
+ tc.greet()
- checkNoStreams := func() error {
- cp, ok := ct.tr.connPool().(*clientConnPool)
- if !ok {
- return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
- }
- cp.mu.Lock()
- defer cp.mu.Unlock()
- conns, ok := cp.conns["dummy.tld:443"]
- if !ok {
- return fmt.Errorf("missing connection")
- }
- if len(conns) != 1 {
- return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
- }
- if activeStreams(conns[0]) != 0 {
- return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
- }
- return nil
- }
- bodyReadError := errors.New("body read error")
- body := &errReader{body, bodyReadError}
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- _, err = ct.tr.RoundTrip(req)
- if err != bodyReadError {
- return fmt.Errorf("err = %v; want %v", err, bodyReadError)
- }
- if err = checkNoStreams(); err != nil {
- return err
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
- var receivedBody []byte
- var resetCount int
- for {
- f, err := ct.fr.ReadFrame()
- t.Logf("server: ReadFrame = %v, %v", f, err)
- if err != nil {
- select {
- case <-clientDone:
- // If the client's done, it
- // will have reported any
- // errors on its side.
- if bytes.Compare(receivedBody, body) != 0 {
- return fmt.Errorf("body: %q; expected %q", receivedBody, body)
- }
- if resetCount != 1 {
- return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
- }
- return nil
- default:
- return err
- }
- }
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- case *DataFrame:
- receivedBody = append(receivedBody, f.Data()...)
- case *RSTStreamFrame:
- resetCount++
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
+ bodyReadError := errors.New("body read error")
+ b := tc.newRequestBody()
+ b.Write(body)
+ b.closeWithError(bodyReadError)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", b)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ var receivedBody []byte
+readFrames:
+ for {
+ switch f := tc.readFrame().(type) {
+ case *DataFrame:
+ receivedBody = append(receivedBody, f.Data()...)
+ case *RSTStreamFrame:
+ break readFrames
+ default:
+ t.Fatalf("unexpected frame: %v", f)
+ case nil:
+ t.Fatalf("transport is idle, want RST_STREAM")
}
}
- ct.run()
+ if !bytes.Equal(receivedBody, body) {
+ t.Fatalf("body: %q; expected %q", receivedBody, body)
+ }
+
+ if err := rt.err(); err != bodyReadError {
+ t.Fatalf("err = %v; want %v", err, bodyReadError)
+ }
+
+ if got := activeStreams(tc.cc); got != 0 {
+ t.Fatalf("active streams count: %v; want 0", got)
+ }
}
func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
@@ -4911,59 +4267,18 @@
const reqBody = "some request body"
const resBody = "some response body"
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- if runtime.GOOS == "plan9" {
- // CloseWrite not supported on Plan 9; Issue 17906
- defer ct.cc.(*net.TCPConn).Close()
- }
- body := strings.NewReader(reqBody)
- req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
- if err != nil {
- return err
- }
- _, err = ct.tr.RoundTrip(req)
- if err != nil {
- return err
- }
- return nil
- }
- ct.server = func() error {
- ct.greet()
+ tc := newTestClientConn(t)
+ tc.greet()
- for {
- f, err := ct.fr.ReadFrame()
- if err != nil {
- return err
- }
+ body := strings.NewReader(reqBody)
+ req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
+ tc.roundTrip(req)
- switch f := f.(type) {
- case *WindowUpdateFrame, *SettingsFrame:
- case *HeadersFrame:
- case *DataFrame:
- if !f.StreamEnded() {
- ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
- return fmt.Errorf("data frame without END_STREAM %v", f)
- }
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.Header().StreamID,
- EndHeaders: true,
- EndStream: false,
- BlockFragment: buf.Bytes(),
- })
- ct.fr.WriteData(f.StreamID, true, []byte(resBody))
- return nil
- case *RSTStreamFrame:
- default:
- return fmt.Errorf("Unexpected client frame %v", f)
- }
- }
+ tc.wantFrameType(FrameHeaders)
+ f := testClientConnReadFrame[*DataFrame](tc)
+ if !f.StreamEnded() {
+ t.Fatalf("data frame without END_STREAM %v", f)
}
- ct.run()
}
type chunkReader struct {
@@ -5737,39 +5052,27 @@
}
func TestTransportTimeoutServerHangs(t *testing.T) {
- clientDone := make(chan struct{})
- ct := newClientTester(t)
- ct.client = func() error {
- defer ct.cc.(*net.TCPConn).CloseWrite()
- defer close(clientDone)
+ tc := newTestClientConn(t)
+ tc.greet()
- req, err := http.NewRequest("PUT", "https://dummy.tld/", nil)
- if err != nil {
- return err
- }
+ ctx, cancel := context.WithCancel(context.Background())
+ req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
- defer cancel()
- req = req.WithContext(ctx)
- req.Header.Add("Big", strings.Repeat("a", 1<<20))
- _, err = ct.tr.RoundTrip(req)
- if err == nil {
- return errors.New("error should not be nil")
- }
- if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
- return fmt.Errorf("error should be a net error timeout: %v", err)
- }
- return nil
+ tc.wantFrameType(FrameHeaders)
+ tc.advance(5 * time.Second)
+ if f := tc.readFrame(); f != nil {
+ t.Fatalf("unexpected frame: %v", f)
}
- ct.server = func() error {
- ct.greet()
- select {
- case <-time.After(5 * time.Second):
- case <-clientDone:
- }
- return nil
+ if rt.done() {
+ t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned")
}
- ct.run()
+
+ cancel()
+ tc.sync()
+ if rt.err() != context.Canceled {
+ t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err())
+ }
}
func TestTransportContentLengthWithoutBody(t *testing.T) {
@@ -5962,20 +5265,6 @@
testTransportClosesConnAfterGoAway(t, 1)
}
-type closeOnceConn struct {
- net.Conn
- closed uint32
-}
-
-var errClosed = errors.New("Close of closed connection")
-
-func (c *closeOnceConn) Close() error {
- if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
- return c.Conn.Close()
- }
- return errClosed
-}
-
// testTransportClosesConnAfterGoAway verifies that the transport
// closes a connection after reading a GOAWAY from it.
//
@@ -5983,53 +5272,35 @@
// When 0, the transport (unsuccessfully) retries the request (stream 1);
// when 1, the transport reads the response after receiving the GOAWAY.
func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) {
- ct := newClientTester(t)
- ct.cc = &closeOnceConn{Conn: ct.cc}
+ tc := newTestClientConn(t)
+ tc.greet()
- var wg sync.WaitGroup
- wg.Add(1)
- ct.client = func() error {
- defer wg.Done()
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
- res, err := ct.tr.RoundTrip(req)
- if err == nil {
- res.Body.Close()
- }
- if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
- t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
- }
- if err = ct.cc.Close(); err != errClosed {
- return fmt.Errorf("ct.cc.Close() = %v, want errClosed", err)
- }
- return nil
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+
+ tc.wantFrameType(FrameHeaders)
+ tc.writeGoAway(lastStream, ErrCodeNo, nil)
+
+ if lastStream > 0 {
+ // Send a valid response to first request.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
}
- ct.server = func() error {
- defer wg.Wait()
- ct.greet()
- hf, err := ct.firstHeaders()
- if err != nil {
- return fmt.Errorf("server failed reading HEADERS: %v", err)
- }
- if err := ct.fr.WriteGoAway(lastStream, ErrCodeNo, nil); err != nil {
- return fmt.Errorf("server failed writing GOAWAY: %v", err)
- }
- if lastStream > 0 {
- // Send a valid response to first request.
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: hf.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
- }
- return nil
+ tc.closeWrite(io.EOF)
+ err := rt.err()
+ if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
+ t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
}
-
- ct.run()
+ if !tc.netConnClosed {
+ t.Errorf("ClientConn did not close its net.Conn, expected it to")
+ }
}
type slowCloser struct {