http2: discard more frames after GOAWAY

After sending a GOAWAY with NO_ERROR, we should discard all frames
for streams with larger identifiers than the last stream identifier
in the GOAWAY frame. We weren't discarding RST_STREAM frames, which
could cause us to incorrectly detect a protocol error when handling
a RST_STREAM for a discarded stream.

Hoist post-GOAWAY frame discarding higher in the loop rather than
handling it on a per-frame-type basis.

We are also supposed to count discarded DATA frames against
connection-level flow control, possibly sending WINDOW_UPDATE
messages to return the flow control. We weren't doing this;
this is now fixed.

Fixes golang/go#55846

Change-Id: I7603a529c00b8637e648eee9cc4608fb5fa5199b
Reviewed-on: https://go-review.googlesource.com/c/net/+/434909
Reviewed-by: Heschi Kreinick <heschi@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: LI ZHEN <mr.imuz@gmail.com>
Reviewed-by: Antonio Ojea <antonio.ojea.garcia@gmail.com>
diff --git a/http2/server.go b/http2/server.go
index 91f8437..2bf2989 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -1459,6 +1459,22 @@
 		sc.sawFirstSettings = true
 	}
 
+	// Discard frames for streams initiated after the identified last
+	// stream sent in a GOAWAY, or all frames after sending an error.
+	// We still need to return connection-level flow control for DATA frames.
+	// RFC 9113 Section 6.8.
+	if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {
+
+		if f, ok := f.(*DataFrame); ok {
+			if sc.inflow.available() < int32(f.Length) {
+				return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl))
+			}
+			sc.inflow.take(int32(f.Length))
+			sc.sendWindowUpdate(nil) // conn-level
+		}
+		return nil
+	}
+
 	switch f := f.(type) {
 	case *SettingsFrame:
 		return sc.processSettings(f)
@@ -1501,9 +1517,6 @@
 		// PROTOCOL_ERROR."
 		return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol))
 	}
-	if sc.inGoAway && sc.goAwayCode != ErrCodeNo {
-		return nil
-	}
 	sc.writeFrame(FrameWriteRequest{write: writePingAck{f}})
 	return nil
 }
@@ -1686,16 +1699,6 @@
 func (sc *serverConn) processData(f *DataFrame) error {
 	sc.serveG.check()
 	id := f.Header().StreamID
-	if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || id > sc.maxClientStreamID) {
-		// Discard all DATA frames if the GOAWAY is due to an
-		// error, or:
-		//
-		// Section 6.8: After sending a GOAWAY frame, the sender
-		// can discard frames for streams initiated by the
-		// receiver with identifiers higher than the identified
-		// last stream.
-		return nil
-	}
 
 	data := f.Data()
 	state, st := sc.state(id)
@@ -1847,10 +1850,6 @@
 func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
 	sc.serveG.check()
 	id := f.StreamID
-	if sc.inGoAway {
-		// Ignore.
-		return nil
-	}
 	// http://tools.ietf.org/html/rfc7540#section-5.1.1
 	// Streams initiated by a client MUST use odd-numbered stream
 	// identifiers. [...] An endpoint that receives an unexpected
@@ -2021,9 +2020,6 @@
 }
 
 func (sc *serverConn) processPriority(f *PriorityFrame) error {
-	if sc.inGoAway {
-		return nil
-	}
 	if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
 		return err
 	}
diff --git a/http2/server_test.go b/http2/server_test.go
index 9721bef..2ce0d8e 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -3994,7 +3994,6 @@
 // and the connection still completing.
 func TestServerHandlerConnectionClose(t *testing.T) {
 	unblockHandler := make(chan bool, 1)
-	defer close(unblockHandler) // backup; in case of errors
 	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
 		w.Header().Set("Connection", "close")
 		w.Header().Set("Foo", "bar")
@@ -4002,6 +4001,7 @@
 		<-unblockHandler
 		return nil
 	}, func(st *serverTester) {
+		defer close(unblockHandler) // backup; in case of errors
 		st.writeHeaders(HeadersFrameParam{
 			StreamID:      1,
 			BlockFragment: st.encodeHeader(),
@@ -4010,6 +4010,7 @@
 		})
 		var sawGoAway bool
 		var sawRes bool
+		var sawWindowUpdate bool
 		for {
 			f, err := st.readFrame()
 			if err == io.EOF {
@@ -4021,10 +4022,29 @@
 			switch f := f.(type) {
 			case *GoAwayFrame:
 				sawGoAway = true
-				unblockHandler <- true
 				if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo {
 					t.Errorf("unexpected GOAWAY frame: %v", summarizeFrame(f))
 				}
+				// Create a stream and reset it.
+				// The server should ignore the stream.
+				st.writeHeaders(HeadersFrameParam{
+					StreamID:      3,
+					BlockFragment: st.encodeHeader(),
+					EndStream:     false,
+					EndHeaders:    true,
+				})
+				st.fr.WriteRSTStream(3, ErrCodeCancel)
+				// Create a stream and send data to it.
+				// The server should return flow control, even though it
+				// does not process the stream.
+				st.writeHeaders(HeadersFrameParam{
+					StreamID:      5,
+					BlockFragment: st.encodeHeader(),
+					EndStream:     false,
+					EndHeaders:    true,
+				})
+				// Write enough data to trigger a window update.
+				st.writeData(5, true, make([]byte, 1<<19))
 			case *HeadersFrame:
 				goth := st.decodeHeader(f.HeaderBlockFragment())
 				wanth := [][2]string{
@@ -4039,6 +4059,17 @@
 				if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 {
 					t.Errorf("unexpected DATA frame: %v", summarizeFrame(f))
 				}
+			case *WindowUpdateFrame:
+				if !sawGoAway {
+					t.Errorf("unexpected WINDOW_UPDATE frame: %v", summarizeFrame(f))
+					return
+				}
+				if f.StreamID != 0 {
+					st.t.Fatalf("WindowUpdate StreamID = %d; want 5", f.FrameHeader.StreamID)
+					return
+				}
+				sawWindowUpdate = true
+				unblockHandler <- true
 			default:
 				t.Logf("unexpected frame: %v", summarizeFrame(f))
 			}
@@ -4049,6 +4080,9 @@
 		if !sawRes {
 			t.Errorf("didn't see response")
 		}
+		if !sawWindowUpdate {
+			t.Errorf("didn't see WINDOW_UPDATE")
+		}
 	})
 }