http2: move HEADERS/CONTINUATION order checking into Framer

Removes state machine complication and duplication out of Server &
Transport and puts it into the Framer instead (where it's nicely
tested).

Also, for testing, start tracking the reason for errors. Later we'll
use it in GOAWAY frames' debug data too.

Change-Id: Ic933654a33edb62b4432c28fe09f7bfdb6f9b334
Reviewed-on: https://go-review.googlesource.com/18101
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
diff --git a/http2/errors.go b/http2/errors.go
index 1f8de65..f320b2c 100644
--- a/http2/errors.go
+++ b/http2/errors.go
@@ -75,3 +75,16 @@
 type goAwayFlowError struct{}
 
 func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
+
+// connErrorReason wraps a ConnectionError with an informative error about why it occurs.
+
+// Errors of this type are only returned by the frame parser functions
+// and converted into ConnectionError(ErrCodeProtocol).
+type connError struct {
+	Code   ErrCode
+	Reason string
+}
+
+func (e connError) Error() string {
+	return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
+}
diff --git a/http2/frame.go b/http2/frame.go
index dc33794..d8c94fa 100644
--- a/http2/frame.go
+++ b/http2/frame.go
@@ -255,6 +255,11 @@
 type Framer struct {
 	r         io.Reader
 	lastFrame Frame
+	errReason string
+
+	// lastHeaderStream is non-zero if the last frame was an
+	// unfinished HEADERS/CONTINUATION.
+	lastHeaderStream uint32
 
 	maxReadSize uint32
 	headerBuf   [frameHeaderLen]byte
@@ -271,13 +276,19 @@
 	wbuf []byte
 
 	// AllowIllegalWrites permits the Framer's Write methods to
-	// write frames that do not conform to the HTTP/2 spec.  This
+	// write frames that do not conform to the HTTP/2 spec. This
 	// permits using the Framer to test other HTTP/2
 	// implementations' conformance to the spec.
 	// If false, the Write methods will prefer to return an error
 	// rather than comply.
 	AllowIllegalWrites bool
 
+	// AllowIllegalReads permits the Framer's ReadFrame method
+	// to return non-compliant frames or frame orders.
+	// This is for testing and permits using the Framer to test
+	// other HTTP/2 implementations' conformance to the spec.
+	AllowIllegalReads bool
+
 	// TODO: track which type of frame & with which flags was sent
 	// last.  Then return an error (unless AllowIllegalWrites) if
 	// we're in the middle of a header block and a
@@ -394,12 +405,65 @@
 	}
 	f, err := typeFrameParser(fh.Type)(fh, payload)
 	if err != nil {
+		if ce, ok := err.(connError); ok {
+			return nil, fr.connError(ce.Code, ce.Reason)
+		}
 		return nil, err
 	}
-	fr.lastFrame = f
+	if err := fr.checkFrameOrder(f); err != nil {
+		return nil, err
+	}
 	return f, nil
 }
 
+// connError returns ConnectionError(code) but first
+// stashes away a public reason to the caller can optionally relay it
+// to the peer before hanging up on them. This might help others debug
+// their implementations.
+func (fr *Framer) connError(code ErrCode, reason string) error {
+	fr.errReason = reason
+	return ConnectionError(code)
+}
+
+// checkFrameOrder reports an error if f is an invalid frame to return
+// next from ReadFrame. Mostly it checks whether HEADERS and
+// CONTINUATION frames are contiguous.
+func (fr *Framer) checkFrameOrder(f Frame) error {
+	last := fr.lastFrame
+	fr.lastFrame = f
+	if fr.AllowIllegalReads {
+		return nil
+	}
+
+	fh := f.Header()
+	if fr.lastHeaderStream != 0 {
+		if fh.Type != FrameContinuation {
+			return fr.connError(ErrCodeProtocol,
+				fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
+					fh.Type, fh.StreamID,
+					last.Header().Type, fr.lastHeaderStream))
+		}
+		if fh.StreamID != fr.lastHeaderStream {
+			return fr.connError(ErrCodeProtocol,
+				fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d",
+					fh.StreamID, fr.lastHeaderStream))
+		}
+	} else if fh.Type == FrameContinuation {
+		return fr.connError(ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID))
+	}
+
+	switch fh.Type {
+	case FrameHeaders, FrameContinuation:
+		if fh.Flags.Has(FlagHeadersEndHeaders) {
+			fr.lastHeaderStream = 0
+		} else {
+			fr.lastHeaderStream = fh.StreamID
+		}
+	}
+
+	return nil
+}
+
 // A DataFrame conveys arbitrary, variable-length sequences of octets
 // associated with a stream.
 // See http://http2.github.io/http2-spec/#rfc.section.6.1
@@ -428,7 +492,7 @@
 		// field is 0x0, the recipient MUST respond with a
 		// connection error (Section 5.4.1) of type
 		// PROTOCOL_ERROR.
-		return nil, ConnectionError(ErrCodeProtocol)
+		return nil, connError{ErrCodeProtocol, "DATA frame with stream ID 0"}
 	}
 	f := &DataFrame{
 		FrameHeader: fh,
@@ -446,7 +510,7 @@
 		// length of the frame payload, the recipient MUST
 		// treat this as a connection error.
 		// Filed: https://github.com/http2/http2-spec/issues/610
-		return nil, ConnectionError(ErrCodeProtocol)
+		return nil, connError{ErrCodeProtocol, "pad size larger than data payload"}
 	}
 	f.data = payload[:len(payload)-int(padSize)]
 	return f, nil
@@ -753,7 +817,7 @@
 		// is received whose stream identifier field is 0x0, the recipient MUST
 		// respond with a connection error (Section 5.4.1) of type
 		// PROTOCOL_ERROR.
-		return nil, ConnectionError(ErrCodeProtocol)
+		return nil, connError{ErrCodeProtocol, "HEADERS frame with stream ID 0"}
 	}
 	var padLength uint8
 	if fh.Flags.Has(FlagHeadersPadded) {
@@ -883,10 +947,10 @@
 
 func parsePriorityFrame(fh FrameHeader, payload []byte) (Frame, error) {
 	if fh.StreamID == 0 {
-		return nil, ConnectionError(ErrCodeProtocol)
+		return nil, connError{ErrCodeProtocol, "PRIORITY frame with stream ID 0"}
 	}
 	if len(payload) != 5 {
-		return nil, ConnectionError(ErrCodeFrameSize)
+		return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))}
 	}
 	v := binary.BigEndian.Uint32(payload[:4])
 	streamID := v & 0x7fffffff // mask off high bit
@@ -956,6 +1020,9 @@
 }
 
 func parseContinuationFrame(fh FrameHeader, p []byte) (Frame, error) {
+	if fh.StreamID == 0 {
+		return nil, connError{ErrCodeProtocol, "CONTINUATION frame with stream ID 0"}
+	}
 	return &ContinuationFrame{fh, p}, nil
 }
 
diff --git a/http2/frame_test.go b/http2/frame_test.go
index 20027c5..281a2d1 100644
--- a/http2/frame_test.go
+++ b/http2/frame_test.go
@@ -6,6 +6,8 @@
 
 import (
 	"bytes"
+	"fmt"
+	"io"
 	"reflect"
 	"strings"
 	"testing"
@@ -264,6 +266,7 @@
 			t.Errorf("test %q: %v", tt.name, err)
 			continue
 		}
+		fr.AllowIllegalReads = true
 		f, err := fr.ReadFrame()
 		if err != nil {
 			t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
@@ -595,3 +598,138 @@
 		t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
 	}
 }
+
+// test checkFrameOrder and that HEADERS and CONTINUATION frames can't be intermingled.
+func TestReadFrameOrder(t *testing.T) {
+	head := func(f *Framer, id uint32, end bool) {
+		f.WriteHeaders(HeadersFrameParam{
+			StreamID:      id,
+			BlockFragment: []byte("foo"), // unused, but non-empty
+			EndHeaders:    end,
+		})
+	}
+	cont := func(f *Framer, id uint32, end bool) {
+		f.WriteContinuation(id, end, []byte("foo"))
+	}
+
+	tests := [...]struct {
+		name    string
+		w       func(*Framer)
+		atLeast int
+		wantErr string
+	}{
+		0: {
+			w: func(f *Framer) {
+				head(f, 1, true)
+			},
+		},
+		1: {
+			w: func(f *Framer) {
+				head(f, 1, true)
+				head(f, 2, true)
+			},
+		},
+		2: {
+			wantErr: "got HEADERS for stream 2; expected CONTINUATION following HEADERS for stream 1",
+			w: func(f *Framer) {
+				head(f, 1, false)
+				head(f, 2, true)
+			},
+		},
+		3: {
+			wantErr: "got DATA for stream 1; expected CONTINUATION following HEADERS for stream 1",
+			w: func(f *Framer) {
+				head(f, 1, false)
+			},
+		},
+		4: {
+			w: func(f *Framer) {
+				head(f, 1, false)
+				cont(f, 1, true)
+				head(f, 2, true)
+			},
+		},
+		5: {
+			wantErr: "got CONTINUATION for stream 2; expected stream 1",
+			w: func(f *Framer) {
+				head(f, 1, false)
+				cont(f, 2, true)
+				head(f, 2, true)
+			},
+		},
+		6: {
+			wantErr: "unexpected CONTINUATION for stream 1",
+			w: func(f *Framer) {
+				cont(f, 1, true)
+			},
+		},
+		7: {
+			wantErr: "unexpected CONTINUATION for stream 1",
+			w: func(f *Framer) {
+				cont(f, 1, false)
+			},
+		},
+		8: {
+			wantErr: "HEADERS frame with stream ID 0",
+			w: func(f *Framer) {
+				head(f, 0, true)
+			},
+		},
+		9: {
+			wantErr: "CONTINUATION frame with stream ID 0",
+			w: func(f *Framer) {
+				cont(f, 0, true)
+			},
+		},
+		10: {
+			wantErr: "unexpected CONTINUATION for stream 1",
+			atLeast: 5,
+			w: func(f *Framer) {
+				head(f, 1, false)
+				cont(f, 1, false)
+				cont(f, 1, false)
+				cont(f, 1, false)
+				cont(f, 1, true)
+				cont(f, 1, false)
+			},
+		},
+	}
+	for i, tt := range tests {
+		buf := new(bytes.Buffer)
+		f := NewFramer(buf, buf)
+		f.AllowIllegalWrites = true
+		tt.w(f)
+		f.WriteData(1, true, nil) // to test transition away from last step
+
+		var err error
+		n := 0
+		var log bytes.Buffer
+		for {
+			var got Frame
+			got, err = f.ReadFrame()
+			fmt.Fprintf(&log, "  read %v, %v\n", got, err)
+			if err != nil {
+				break
+			}
+			n++
+		}
+		if err == io.EOF {
+			err = nil
+		}
+		ok := tt.wantErr == ""
+		if ok && err != nil {
+			t.Errorf("%d. after %d good frames, ReadFrame = %v; want success\n%s", i, n, err, log.Bytes())
+			continue
+		}
+		if !ok && err != ConnectionError(ErrCodeProtocol) {
+			t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes())
+			continue
+		}
+		if f.errReason != tt.wantErr {
+			t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errReason, tt.wantErr, log.Bytes())
+		}
+		if n < tt.atLeast {
+			t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes())
+		}
+	}
+}
diff --git a/http2/server.go b/http2/server.go
index e9b21ac..9326eb6 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -1015,18 +1015,6 @@
 	}
 }
 
-// curHeaderStreamID returns the stream ID of the header block we're
-// currently in the middle of reading. If this returns non-zero, the
-// next frame must be a CONTINUATION with this stream id.
-func (sc *serverConn) curHeaderStreamID() uint32 {
-	sc.serveG.check()
-	st := sc.req.stream
-	if st == nil {
-		return 0
-	}
-	return st.id
-}
-
 // processFrameFromReader processes the serve loop's read from readFrameCh from the
 // frame-reading goroutine.
 // processFrameFromReader returns whether the connection should be kept open.
@@ -1091,14 +1079,6 @@
 		sc.sawFirstSettings = true
 	}
 
-	if s := sc.curHeaderStreamID(); s != 0 {
-		if cf, ok := f.(*ContinuationFrame); !ok {
-			return ConnectionError(ErrCodeProtocol)
-		} else if cf.Header().StreamID != s {
-			return ConnectionError(ErrCodeProtocol)
-		}
-	}
-
 	switch f := f.(type) {
 	case *SettingsFrame:
 		return sc.processSettings(f)
@@ -1437,9 +1417,6 @@
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
 	sc.serveG.check()
 	st := sc.streams[f.Header().StreamID]
-	if st == nil || sc.curHeaderStreamID() != st.id {
-		return ConnectionError(ErrCodeProtocol)
-	}
 	if st.gotTrailerHeader {
 		return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
 	}
diff --git a/http2/transport.go b/http2/transport.go
index 734fd56..781be9d 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -851,10 +851,6 @@
 	cc        *ClientConn
 	activeRes map[uint32]*clientStream // keyed by streamID
 
-	// continueStreamID is the stream ID we're waiting for
-	// continuation frames for.
-	continueStreamID uint32
-
 	hdec *hpack.Decoder
 
 	// Fields reset on each HEADERS:
@@ -924,21 +920,6 @@
 		}
 		cc.vlogf("Transport received %v: %#v", f.Header(), f)
 
-		streamID := f.Header().StreamID
-
-		_, isContinue := f.(*ContinuationFrame)
-		if isContinue {
-			if streamID != rl.continueStreamID {
-				cc.logf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, rl.continueStreamID)
-				return ConnectionError(ErrCodeProtocol)
-			}
-		} else if rl.continueStreamID != 0 {
-			// Continue frames need to be adjacent in the stream
-			// and we were in the middle of headers.
-			cc.logf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, rl.continueStreamID)
-			return ConnectionError(ErrCodeProtocol)
-		}
-
 		switch f := f.(type) {
 		case *HeadersFrame:
 			err = rl.processHeaders(f)
@@ -986,13 +967,12 @@
 	cc := rl.cc
 	cs := cc.streamByID(streamID, streamEnded)
 	if cs == nil {
-		// We could return a ConnectionError(ErrCodeProtocol)
-		// here, except that in the case of us canceling
-		// client requests, we may also delete from the
-		// streams map, in which case we forgot that we sent
-		// this request. So, just ignore any responses for
-		// now.  They might've been in-flight before the
-		// server got our RST_STREAM.
+		// We'd get here if we canceled a request while the
+		// server was mid-way through replying with its
+		// headers. (The case of a CONTINUATION arriving
+		// without HEADERS would be rejected earlier by the
+		// Framer). So if this was just something we canceled,
+		// ignore it.
 		return nil
 	}
 	if cs.headersDone {
@@ -1004,12 +984,12 @@
 	if err != nil {
 		return ConnectionError(ErrCodeCompression)
 	}
+	if err := rl.hdec.Close(); err != nil {
+		return ConnectionError(ErrCodeCompression)
+	}
 	if !headersEnded {
-		rl.continueStreamID = cs.ID
 		return nil
 	}
-	// HEADERS (or CONTINUATION) are now over.
-	rl.continueStreamID = 0
 
 	if !cs.headersDone {
 		cs.headersDone = true