http2: shut down idle Transport connections after protocol errors
Change-Id: Ic4e85cdc75b4baef7cd61a65b1b09f430a2ffc4b
Reviewed-on: https://go-review.googlesource.com/c/net/+/352449
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/http2/transport.go b/http2/transport.go
index 0ab5ad3..a7f113b 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -868,6 +868,12 @@
cc.tconn.Close()
}
+func (cc *ClientConn) isDoNotReuseAndIdle() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.doNotReuse && len(cc.streams) == 0
+}
+
var shutdownEnterWaitStateHook = func() {}
// Shutdown gracefully close the client connection, waiting for running streams to complete.
@@ -2304,6 +2310,9 @@
func (rl *clientConnReadLoop) processData(f *DataFrame) error {
cc := rl.cc
cs := cc.streamByID(f.StreamID, f.StreamEnded())
+ if f.StreamEnded() && cc.isDoNotReuseAndIdle() {
+ rl.closeWhenIdle = true
+ }
data := f.Data()
if cs == nil {
cc.mu.Lock()
@@ -2554,11 +2563,15 @@
}
func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
- cs := rl.cc.streamByID(f.StreamID, true)
+ cc := rl.cc
+ cs := cc.streamByID(f.StreamID, true)
if cs == nil {
// TODO: return error if server tries to RST_STEAM an idle stream
return nil
}
+ if cc.isDoNotReuseAndIdle() {
+ rl.closeWhenIdle = true
+ }
select {
case <-cs.peerReset:
// Already reset.
@@ -2570,6 +2583,7 @@
if f.ErrCode == ErrCodeProtocol {
rl.cc.SetDoNotReuse()
serr.Cause = errFromPeer
+ rl.closeWhenIdle = true
}
if fn := cs.cc.t.CountError; fn != nil {
fn("recv_rststream_" + f.ErrCode.stringToken())
diff --git a/http2/transport_test.go b/http2/transport_test.go
index f1d5761..dd0860d 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5303,68 +5303,119 @@
lower: &clientConnPool{t: ct.tr},
}
ct.tr.ConnPool = pool
- done := make(chan struct{})
+
+ gotProtoError := make(chan bool, 1)
+ ct.tr.CountError = func(errType string) {
+ if errType == "recv_rststream_PROTOCOL_ERROR" {
+ select {
+ case gotProtoError <- true:
+ default:
+ }
+ }
+ }
ct.client = func() error {
- req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ // Start two requests. The first is a long request
+ // that will finish after the second. The second one
+ // will result in the protocol error. We check that
+ // after the first one closes, the connection then
+ // shuts down.
+
+ // The long, outer request.
+ req1, _ := http.NewRequest("GET", "https://dummy.tld/long", nil)
+ res1, err := ct.tr.RoundTrip(req1)
+ if err != nil {
+ return err
+ }
+ if got, want := res1.Header.Get("Is-Long"), "1"; got != want {
+ return fmt.Errorf("First response's Is-Long header = %q; want %q", got, want)
+ }
+
+ req, _ := http.NewRequest("POST", "https://dummy.tld/fails", nil)
res, err := ct.tr.RoundTrip(req)
const want = "only one dial allowed in test mode"
if got := fmt.Sprint(err); got != want {
t.Errorf("didn't dial again: got %#q; want %#q", got, want)
}
- close(done)
- ct.sc.Close()
if res != nil {
res.Body.Close()
}
+ select {
+ case <-gotProtoError:
+ default:
+ t.Errorf("didn't get stream protocol error")
+ }
+
+ if n, err := res1.Body.Read(make([]byte, 10)); err != io.EOF || n != 0 {
+ t.Errorf("unexpected body read %v, %v", n, err)
+ }
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.getErrs != 1 {
t.Errorf("pool get errors = %v; want 1", pool.getErrs)
}
- if len(pool.got) == 1 {
+ if len(pool.got) == 2 {
+ if pool.got[0] != pool.got[1] {
+ t.Errorf("requests went on different connections")
+ }
cc := pool.got[0]
cc.mu.Lock()
if !cc.doNotReuse {
t.Error("ClientConn not marked doNotReuse")
}
cc.mu.Unlock()
+
+ select {
+ case <-cc.readerDone:
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for reader to be done")
+ }
} else {
- t.Errorf("pool get success = %v; want 1", len(pool.got))
+ t.Errorf("pool get success = %v; want 2", len(pool.got))
}
return nil
}
ct.server = func() error {
ct.greet()
var sentErr bool
+ var numHeaders int
+ var firstStreamID uint32
+
+ var hbuf bytes.Buffer
+ enc := hpack.NewEncoder(&hbuf)
+
for {
f, err := ct.fr.ReadFrame()
+ if err == io.EOF {
+ // Client hung up on us, as it should at the end.
+ return nil
+ }
if err != nil {
- select {
- case <-done:
- return nil
- default:
- return err
- }
+ return err
}
switch f := f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
case *HeadersFrame:
+ numHeaders++
+ if numHeaders == 1 {
+ firstStreamID = f.StreamID
+ hbuf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+ enc.WriteField(hpack.HeaderField{Name: "is-long", Value: "1"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: f.StreamID,
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: hbuf.Bytes(),
+ })
+ continue
+ }
if !sentErr {
sentErr = true
ct.fr.WriteRSTStream(f.StreamID, ErrCodeProtocol)
+ ct.fr.WriteData(firstStreamID, true, nil)
continue
}
- var buf bytes.Buffer
- enc := hpack.NewEncoder(&buf)
- // send headers without Trailer header
- enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- ct.fr.WriteHeaders(HeadersFrameParam{
- StreamID: f.StreamID,
- EndHeaders: true,
- EndStream: true,
- BlockFragment: buf.Bytes(),
- })
}
}
return nil