http2: make Transport return server's GOAWAY error back to the user
Updates golang/go#14627 (fixes once bundled into std)
Change-Id: Iae91d165df749e06549a25f9664ee416f115573f
Reviewed-on: https://go-review.googlesource.com/24560
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 631a04b..e1274b0 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2011,3 +2011,77 @@
time.Sleep(1 * time.Millisecond)
}
}
+
+// golang.org/issue/14627: once the server has sent a GOAWAY frame, the
+// Transport must remember it and hand it back to the caller if the server
+// then closes the TCP connection before the client has its response. The
+// error can surface from RoundTrip (tested here) or from reading the
+// response body (TestTransportUsesGoAwayDebugError_Body).
+func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
+	testTransportUsesGoAwayDebugError(t, false /* failMidBody */)
+}
+
+func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
+	testTransportUsesGoAwayDebugError(t, true /* failMidBody: GOAWAY arrives after the response headers */)
+}
+
+// testTransportUsesGoAwayDebugError checks that a server GOAWAY followed by a connection
+// close surfaces as a GoAwayError from RoundTrip, or from reading res.Body if failMidBody.
+func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
+	ct := newClientTester(t)
+	clientDone := make(chan struct{}) // the server blocks on this after closing the conn
+
+	const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
+	const goAwayDebugData = "some debug data"
+
+	ct.client = func() error {
+		defer close(clientDone)
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		if failMidBody {
+			if err != nil {
+				return fmt.Errorf("unexpected client RoundTrip error: %v", err)
+			}
+			_, err = io.Copy(ioutil.Discard, res.Body)
+			res.Body.Close()
+		} else if err == nil {
+			res.Body.Close() // RoundTrip should have failed; don't leak the body.
+		}
+		want := GoAwayError{LastStreamID: 0, ErrCode: goAwayErrCode, DebugData: goAwayDebugData}
+		if !reflect.DeepEqual(err, want) {
+			t.Errorf("client error = %T: %#v, want %T (%#v)", err, err, want, want)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				t.Logf("ReadFrame: %v", err)
+				return nil
+			}
+			hf, ok := f.(*HeadersFrame)
+			if !ok {
+				continue
+			}
+			if failMidBody {
+				var buf bytes.Buffer
+				enc := hpack.NewEncoder(&buf)
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+				enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"}) // body never sent
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      hf.StreamID,
+					EndHeaders:    true,
+					EndStream:     false,
+					BlockFragment: buf.Bytes(),
+				})
+			}
+			ct.fr.WriteGoAway(0, goAwayErrCode, []byte(goAwayDebugData))
+			ct.sc.Close()
+			<-clientDone
+			return nil
+		}
+	}
+	ct.run()
+}