http2: make Transport ignore 100-continue responses, add comprehensive tests
This makes the Transport ignore 100-continue responses from servers,
rather than getting confused by them. This is good enough for
golang/go#13659. I filed golang/go#13851 to do better later, but
that's less important.
This CL also adds comprehensive tests for the 36 different ways that
frames can be arranged from servers when reading a response. That
exposed some bugs (now fixed), and even affected the http2 API: I'd
added an END_STREAM accessor on CONTINUATION frames, but it's not even
valid there.
I also renamed some confusing variables whose names sounded too similar.
Updates golang/go#13659
Updates golang/go#13851
Change-Id: I58868a27258981267f1b2043f711f50a42ec744a
Reviewed-on: https://go-review.googlesource.com/18370
Reviewed-by: Andrew Gerrand <adg@golang.org>
diff --git a/http2/frame.go b/http2/frame.go
index d8c94fa..ec3f796 100644
--- a/http2/frame.go
+++ b/http2/frame.go
@@ -1026,10 +1026,6 @@
return &ContinuationFrame{fh, p}, nil
}
-func (f *ContinuationFrame) StreamEnded() bool {
- return f.FrameHeader.Flags.Has(FlagDataEndStream)
-}
-
func (f *ContinuationFrame) HeaderBlockFragment() []byte {
f.checkValid()
return f.headerFragBuf
diff --git a/http2/transport.go b/http2/transport.go
index 9104270..485c823 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -166,8 +166,8 @@
done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu
// owned by clientConnReadLoop:
- headersDone bool // got HEADERS w/ END_HEADERS
- trailersDone bool // got second HEADERS frame w/ END_HEADERS
+ pastHeaders bool // got HEADERS w/ END_HEADERS
+ pastTrailers bool // got second HEADERS frame w/ END_HEADERS
trailer http.Header // accumulated trailers
resTrailer http.Header // client's Response.Trailer
@@ -923,9 +923,10 @@
hdec *hpack.Decoder
// Fields reset on each HEADERS:
- nextRes *http.Response
- sawRegHeader bool // saw non-pseudo header
- reqMalformed error // non-nil once known to be malformed
+ nextRes *http.Response
+ sawRegHeader bool // saw non-pseudo header
+ reqMalformed error // non-nil once known to be malformed
+ lastHeaderEndsStream bool
}
// readLoop runs in its own goroutine and reads and dispatches frames.
@@ -1021,21 +1022,23 @@
func (rl *clientConnReadLoop) processHeaders(f *HeadersFrame) error {
rl.sawRegHeader = false
rl.reqMalformed = nil
+ rl.lastHeaderEndsStream = f.StreamEnded() // latched here: END_STREAM is only valid on HEADERS, never on the CONTINUATION frames that follow
rl.nextRes = &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: make(http.Header),
}
- return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded(), f.StreamEnded())
+ return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
}
func (rl *clientConnReadLoop) processContinuation(f *ContinuationFrame) error {
- return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded(), f.StreamEnded())
+ return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded()) // stream-end state comes from rl.lastHeaderEndsStream, set by processHeaders
}
-func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, headersEnded, streamEnded bool) error {
+func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error {
cc := rl.cc
- cs := cc.streamByID(streamID, streamEnded)
+ streamEnded := rl.lastHeaderEndsStream
+ cs := cc.streamByID(streamID, streamEnded && finalFrag)
if cs == nil {
// We'd get here if we canceled a request while the
// server was mid-way through replying with its
@@ -1045,7 +1048,7 @@
// ignore it.
return nil
}
- if cs.headersDone {
+ if cs.pastHeaders {
rl.hdec.SetEmitFunc(cs.onNewTrailerField)
} else {
rl.hdec.SetEmitFunc(rl.onNewHeaderField)
@@ -1054,23 +1057,26 @@
if err != nil {
return ConnectionError(ErrCodeCompression)
}
- if err := rl.hdec.Close(); err != nil {
- return ConnectionError(ErrCodeCompression)
+ if finalFrag {
+ if err := rl.hdec.Close(); err != nil {
+ return ConnectionError(ErrCodeCompression)
+ }
}
- if !headersEnded {
+
+ if !finalFrag {
return nil
}
- if !cs.headersDone {
- cs.headersDone = true
+ if !cs.pastHeaders {
+ cs.pastHeaders = true
} else {
// We're dealing with trailers. (and specifically the
// final frame of headers)
- if cs.trailersDone {
+ if cs.pastTrailers {
// Too many HEADERS frames for this stream.
return ConnectionError(ErrCodeProtocol)
}
- cs.trailersDone = true
+ cs.pastTrailers = true
if !streamEnded {
// We expect that any header block fragment
// frame for trailers with END_HEADERS also
@@ -1089,6 +1095,13 @@
res := rl.nextRes
+ if res.StatusCode == 100 {
+ // Just skip 100-continue response headers for now.
+ // TODO: golang.org/issue/13851 for doing it properly.
+ cs.pastHeaders = false // do it all again
+ return nil
+ }
+
if !streamEnded || cs.req.Method == "HEAD" {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 80f754b..7a392a9 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -575,10 +575,15 @@
}
}
+func (ct *clientTester) cleanup() {
+ ct.tr.CloseIdleConnections() // deferred by run: close conns left idle in ct.tr once client and server are done
+}
+
func (ct *clientTester) run() {
errc := make(chan error, 2)
ct.start("client", errc, ct.client)
ct.start("server", errc, ct.server)
+ defer ct.cleanup()
for i := 0; i < 2; i++ {
if err := <-errc; err != nil {
ct.t.Error(err)
@@ -819,3 +824,181 @@
}
}
}
+
+type headerType int // how the test server frames a header block
+
+const (
+ noHeader headerType = iota // omitted
+ oneHeader // one HEADERS frame with END_HEADERS set
+ splitHeader // broken into HEADERS + CONTINUATION on purpose
+)
+
+const (
+ f0 = noHeader // f*/d* are short aliases used by the generated TestTransportResPattern_* calls
+ f1 = oneHeader
+ f2 = splitHeader
+ d0 = false // withData: no DATA frame in the response
+ d1 = true // withData: response body sent in one DATA frame
+)
+
+// Test all 36 combinations of response frame orders; names encode c=100-continue, h=headers, d=data, t=trailers:
+// (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers).
+// Generated by http://play.golang.org/p/SScqYKJYXd
+func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
+func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
+func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
+func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
+func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
+func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
+func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
+func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
+func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
+func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
+func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
+func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
+func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
+func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
+func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
+func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
+func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
+func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
+func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
+func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
+func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
+func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
+func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
+func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
+func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
+func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
+func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
+func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
+func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
+func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
+func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
+func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
+func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
+func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
+func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
+func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
+
+func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
+ const reqBody = "some request body"
+ const resBody = "some response body"
+ // Replays one server frame arrangement for a POST and checks what the client observes.
+ if resHeader == noHeader {
+ // TODO: test 100-continue followed by immediate
+ // server stream reset, without headers in the middle?
+ t.Fatal("invalid combination")
+ }
+ // client: POST reqBody (optionally with Expect: 100-continue); expect 200, the body, and trailers.
+ ct := newClientTester(t)
+ ct.client = func() error {
+ req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
+ if expect100Continue != noHeader {
+ req.Header.Set("Expect", "100-continue")
+ }
+ res, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ return fmt.Errorf("RoundTrip: %v", err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ return fmt.Errorf("status code = %v; want 200", res.StatusCode)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("Slurp: %v", err)
+ }
+ wantBody := resBody
+ if !withData {
+ wantBody = ""
+ }
+ if string(slurp) != wantBody {
+ return fmt.Errorf("body = %q; want %q", slurp, wantBody)
+ }
+ if trailers == noHeader {
+ if len(res.Trailer) > 0 {
+ return fmt.Errorf("Trailer = %v; want none", res.Trailer)
+ }
+ } else {
+ want := http.Header{"Some-Trailer": {"some-value"}}
+ if !reflect.DeepEqual(res.Trailer, want) {
+ return fmt.Errorf("Trailer = %v; want %v", res.Trailer, want)
+ }
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ // server: answer the first HEADERS frame with the frame pattern under test.
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ return err
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *DataFrame:
+ // ignore the request body for now.
+ case *HeadersFrame:
+ endStream := false
+ send := func(mode headerType) {
+ hbf := buf.Bytes()
+ switch mode {
+ case oneHeader:
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: endStream,
+ BlockFragment: hbf,
+ })
+ case splitHeader:
+ if len(hbf) < 2 {
+ panic("too small")
+ }
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: false,
+ EndStream: endStream,
+ BlockFragment: hbf[:1],
+ })
+ ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
+ default:
+ panic("bogus mode")
+ }
+ }
+ if expect100Continue != noHeader {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
+ send(expect100Continue)
+ }
+ // Response headers (1+ frames; 1 or 2 in this test, but never 0)
+ {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
+ enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
+ if trailers != noHeader {
+ enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
+ }
+ endStream = !withData && trailers == noHeader
+ send(resHeader)
+ }
+ if withData {
+ endStream = trailers == noHeader
+ ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
+ }
+ if trailers != noHeader {
+ endStream = true
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
+ send(trailers)
+ }
+ return nil
+ }
+ }
+ }
+ ct.run()
+}