http2: handle server errors after sending GOAWAY
The HTTP/2 server uses serverConn.goAwayCode to track whether a
connection has encountered a fatal error. If an error is encountered
after sending a ErrCodeNo GOAWAY, upgrade goAwayCode to reflect the
error status of the connection.
Fixes an issue where a server connection could hang forever waiting
for a clean shutdown that was preempted by a subsequent fatal error.
Fixes CVE-2022-27664
For golang/go#53977
Change-Id: I165b81ab53176c77a68c42976030499d57bb05d3
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1413887
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Roland Shoemaker <bracewell@google.com>
Reviewed-on: https://go-review.googlesource.com/c/net/+/428735
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/http2/server.go b/http2/server.go
index aa3b086..fd873b9 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -1371,6 +1371,9 @@
func (sc *serverConn) goAway(code ErrCode) {
sc.serveG.check()
if sc.inGoAway {
+ if sc.goAwayCode == ErrCodeNo {
+ sc.goAwayCode = code
+ }
return
}
sc.inGoAway = true
diff --git a/http2/server_test.go b/http2/server_test.go
index ddd3daf..5a54de2 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -4366,3 +4366,46 @@
}
})
}
+
+func TestProtocolErrorAfterGoAway(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ io.Copy(io.Discard, r.Body)
+ })
+ defer st.Close()
+
+ st.greet()
+ content := "some content"
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(
+ ":method", "POST",
+ "content-length", strconv.Itoa(len(content)),
+ ),
+ EndStream: false,
+ EndHeaders: true,
+ })
+ st.writeData(1, false, []byte(content[:5]))
+
+ _, err := st.readFrame()
+ if err != nil {
+ st.t.Fatal(err)
+ }
+
+ // Send a GOAWAY with ErrCodeNo, followed by a bogus window update.
+ // The server should close the connection.
+ if err := st.fr.WriteGoAway(1, ErrCodeNo, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
+ t.Fatal(err)
+ }
+
+ for {
+ if _, err := st.readFrame(); err != nil {
+ if err != io.EOF {
+ t.Errorf("unexpected readFrame error: %v", err)
+ }
+ break
+ }
+ }
+}