http2: make Transport ignore 100-continue responses, add comprehensive tests

This makes the Transport ignore 100-continue responses from servers,
rather than get confused by them. This is good enough for
golang/go#13659. I filed golang/go#13851 to do better later, but
that's less important.

This CL also adds comprehensive tests for the 36 different ways that
frames can be arranged from servers when reading a response. That
exposed some bugs (now fixed), and even affected the http2 API: I'd
added an END_STREAM accessor on CONTINUATION frames, but it's not even
valid there.

I also renamed some confusing variables which sounded too similar.

Updates golang/go#13659
Updates golang/go#13851

Change-Id: I58868a27258981267f1b2043f711f50a42ec744a
Reviewed-on: https://go-review.googlesource.com/18370
Reviewed-by: Andrew Gerrand <adg@golang.org>
diff --git a/http2/frame.go b/http2/frame.go
index d8c94fa..ec3f796 100644
--- a/http2/frame.go
+++ b/http2/frame.go
@@ -1026,10 +1026,6 @@
 	return &ContinuationFrame{fh, p}, nil
 }
 
-func (f *ContinuationFrame) StreamEnded() bool {
-	return f.FrameHeader.Flags.Has(FlagDataEndStream)
-}
-
 func (f *ContinuationFrame) HeaderBlockFragment() []byte {
 	f.checkValid()
 	return f.headerFragBuf
diff --git a/http2/transport.go b/http2/transport.go
index 9104270..485c823 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -166,8 +166,8 @@
 	done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu
 
 	// owned by clientConnReadLoop:
-	headersDone  bool // got HEADERS w/ END_HEADERS
-	trailersDone bool // got second HEADERS frame w/ END_HEADERS
+	pastHeaders  bool // got HEADERS w/ END_HEADERS
+	pastTrailers bool // got second HEADERS frame w/ END_HEADERS
 
 	trailer    http.Header // accumulated trailers
 	resTrailer http.Header // client's Response.Trailer
@@ -923,9 +923,10 @@
 	hdec *hpack.Decoder
 
 	// Fields reset on each HEADERS:
-	nextRes      *http.Response
-	sawRegHeader bool  // saw non-pseudo header
-	reqMalformed error // non-nil once known to be malformed
+	nextRes              *http.Response
+	sawRegHeader         bool  // saw non-pseudo header
+	reqMalformed         error // non-nil once known to be malformed
+	lastHeaderEndsStream bool
 }
 
 // readLoop runs in its own goroutine and reads and dispatches frames.
@@ -1021,21 +1022,23 @@
 func (rl *clientConnReadLoop) processHeaders(f *HeadersFrame) error {
 	rl.sawRegHeader = false
 	rl.reqMalformed = nil
+	rl.lastHeaderEndsStream = f.StreamEnded()
 	rl.nextRes = &http.Response{
 		Proto:      "HTTP/2.0",
 		ProtoMajor: 2,
 		Header:     make(http.Header),
 	}
-	return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded(), f.StreamEnded())
+	return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
 }
 
 func (rl *clientConnReadLoop) processContinuation(f *ContinuationFrame) error {
-	return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded(), f.StreamEnded())
+	return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
 }
 
-func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, headersEnded, streamEnded bool) error {
+func (rl *clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error {
 	cc := rl.cc
-	cs := cc.streamByID(streamID, streamEnded)
+	streamEnded := rl.lastHeaderEndsStream
+	cs := cc.streamByID(streamID, streamEnded && finalFrag)
 	if cs == nil {
 		// We'd get here if we canceled a request while the
 		// server was mid-way through replying with its
@@ -1045,7 +1048,7 @@
 		// ignore it.
 		return nil
 	}
-	if cs.headersDone {
+	if cs.pastHeaders {
 		rl.hdec.SetEmitFunc(cs.onNewTrailerField)
 	} else {
 		rl.hdec.SetEmitFunc(rl.onNewHeaderField)
@@ -1054,23 +1057,26 @@
 	if err != nil {
 		return ConnectionError(ErrCodeCompression)
 	}
-	if err := rl.hdec.Close(); err != nil {
-		return ConnectionError(ErrCodeCompression)
+	if finalFrag {
+		if err := rl.hdec.Close(); err != nil {
+			return ConnectionError(ErrCodeCompression)
+		}
 	}
-	if !headersEnded {
+
+	if !finalFrag {
 		return nil
 	}
 
-	if !cs.headersDone {
-		cs.headersDone = true
+	if !cs.pastHeaders {
+		cs.pastHeaders = true
 	} else {
 		// We're dealing with trailers. (and specifically the
 		// final frame of headers)
-		if cs.trailersDone {
+		if cs.pastTrailers {
 			// Too many HEADERS frames for this stream.
 			return ConnectionError(ErrCodeProtocol)
 		}
-		cs.trailersDone = true
+		cs.pastTrailers = true
 		if !streamEnded {
 			// We expect that any header block fragment
 			// frame for trailers with END_HEADERS also
@@ -1089,6 +1095,13 @@
 
 	res := rl.nextRes
 
+	if res.StatusCode == 100 {
+		// Just skip 100-continue response headers for now.
+		// TODO: golang.org/issue/13851 for doing it properly.
+		cs.pastHeaders = false // do it all again
+		return nil
+	}
+
 	if !streamEnded || cs.req.Method == "HEAD" {
 		res.ContentLength = -1
 		if clens := res.Header["Content-Length"]; len(clens) == 1 {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 80f754b..7a392a9 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -575,10 +575,15 @@
 	}
 }
 
+func (ct *clientTester) cleanup() {
+	ct.tr.CloseIdleConnections()
+}
+
 func (ct *clientTester) run() {
 	errc := make(chan error, 2)
 	ct.start("client", errc, ct.client)
 	ct.start("server", errc, ct.server)
+	defer ct.cleanup()
 	for i := 0; i < 2; i++ {
 		if err := <-errc; err != nil {
 			ct.t.Error(err)
@@ -819,3 +824,181 @@
 		}
 	}
 }
+
+type headerType int
+
+const (
+	noHeader headerType = iota // omitted
+	oneHeader
+	splitHeader // broken into continuation on purpose
+)
+
+const (
+	f0 = noHeader
+	f1 = oneHeader
+	f2 = splitHeader
+	d0 = false
+	d1 = true
+)
+
+// Test all 36 combinations of response frame orders:
+//    (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers)
+// Generated by http://play.golang.org/p/SScqYKJYXd
+func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
+func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
+func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
+func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
+func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
+func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
+func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
+func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
+func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
+func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
+func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
+func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
+func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
+func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
+func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
+func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
+func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
+func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
+func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
+func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
+func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
+func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
+func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
+func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
+func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
+func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
+func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
+func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
+func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
+func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
+func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
+func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
+func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
+func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
+func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
+func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
+
+func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
+	const reqBody = "some request body"
+	const resBody = "some response body"
+
+	if resHeader == noHeader {
+		// TODO: test 100-continue followed by immediate
+		// server stream reset, without headers in the middle?
+		panic("invalid combination")
+	}
+
+	ct := newClientTester(t)
+	ct.client = func() error {
+		req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
+		if expect100Continue != noHeader {
+			req.Header.Set("Expect", "100-continue")
+		}
+		res, err := ct.tr.RoundTrip(req)
+		if err != nil {
+			return fmt.Errorf("RoundTrip: %v", err)
+		}
+		defer res.Body.Close()
+		if res.StatusCode != 200 {
+			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
+		}
+		slurp, err := ioutil.ReadAll(res.Body)
+		if err != nil {
+			return fmt.Errorf("Slurp: %v", err)
+		}
+		wantBody := resBody
+		if !withData {
+			wantBody = ""
+		}
+		if string(slurp) != wantBody {
+			return fmt.Errorf("body = %q; want %q", slurp, wantBody)
+		}
+		if trailers == noHeader {
+			if len(res.Trailer) > 0 {
+				t.Errorf("Trailer = %v; want none", res.Trailer)
+			}
+		} else {
+			want := http.Header{"Some-Trailer": {"some-value"}}
+			if !reflect.DeepEqual(res.Trailer, want) {
+				t.Errorf("Trailer = %v; want %v", res.Trailer, want)
+			}
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				return err
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+			case *DataFrame:
+				// ignore for now.
+			case *HeadersFrame:
+				endStream := false
+				send := func(mode headerType) {
+					hbf := buf.Bytes()
+					switch mode {
+					case oneHeader:
+						ct.fr.WriteHeaders(HeadersFrameParam{
+							StreamID:      f.StreamID,
+							EndHeaders:    true,
+							EndStream:     endStream,
+							BlockFragment: hbf,
+						})
+					case splitHeader:
+						if len(hbf) < 2 {
+							panic("too small")
+						}
+						ct.fr.WriteHeaders(HeadersFrameParam{
+							StreamID:      f.StreamID,
+							EndHeaders:    false,
+							EndStream:     endStream,
+							BlockFragment: hbf[:1],
+						})
+						ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
+					default:
+						panic("bogus mode")
+					}
+				}
+				if expect100Continue != noHeader {
+					buf.Reset()
+					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
+					send(expect100Continue)
+				}
+				// Response headers (1+ frames; 1 or 2 in this test, but never 0)
+				{
+					buf.Reset()
+					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+					enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
+					enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
+					if trailers != noHeader {
+						enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
+					}
+					endStream = withData == false && trailers == noHeader
+					send(resHeader)
+				}
+				if withData {
+					endStream = trailers == noHeader
+					ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
+				}
+				if trailers != noHeader {
+					endStream = true
+					buf.Reset()
+					enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
+					send(trailers)
+				}
+				return nil
+			}
+		}
+	}
+	ct.run()
+}