http2: don't leak streams on broken body
Updates golang/go#27208
Change-Id: I5d9a643f33d27d33b24f670c98f5a51aa6000967
GitHub-Last-Rev: 3ac4a573b62846ef4944599085218e119819383c
GitHub-Pull-Request: golang/net#18
Reviewed-on: https://go-review.googlesource.com/c/132715
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 2c9fe88..3fe2918 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -1100,6 +1100,7 @@
default:
}
if err != nil {
+ cc.forgetStreamID(cs.ID)
return nil, cs.getStartedWrite(), err
}
bodyWritten = true
@@ -1221,6 +1222,7 @@
sawEOF = true
err = nil
} else if err != nil {
+ cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
return err
}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index f61ff50..f6efa61 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -4180,3 +4180,99 @@
t.Fatalf("wrong kind %T; want *Transport", v.Interface())
}
}
+
+type errReader struct {
+ body []byte
+ err error
+}
+
+func (r *errReader) Read(p []byte) (int, error) {
+ if len(r.body) > 0 {
+ n := copy(p, r.body)
+ r.body = r.body[n:]
+ return n, nil
+ }
+ return 0, r.err
+}
+
+func testTransportBodyReadError(t *testing.T, body []byte) {
+ clientDone := make(chan struct{})
+ ct := newClientTester(t)
+ ct.client = func() error {
+ defer ct.cc.(*net.TCPConn).CloseWrite()
+ defer close(clientDone)
+
+ checkNoStreams := func() error {
+ cp, ok := ct.tr.connPool().(*clientConnPool)
+ if !ok {
+ return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
+ }
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ conns, ok := cp.conns["dummy.tld:443"]
+ if !ok {
+ return fmt.Errorf("missing connection")
+ }
+ if len(conns) != 1 {
+ return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
+ }
+ if activeStreams(conns[0]) != 0 {
+ return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
+ }
+ return nil
+ }
+ bodyReadError := errors.New("body read error")
+ body := &errReader{body, bodyReadError}
+ req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
+ if err != nil {
+ return err
+ }
+ _, err = ct.tr.RoundTrip(req)
+ if err != bodyReadError {
+ return fmt.Errorf("err = %v; want %v", err, bodyReadError)
+ }
+ if err = checkNoStreams(); err != nil {
+ return err
+ }
+ return nil
+ }
+ ct.server = func() error {
+ ct.greet()
+ var receivedBody []byte
+ var resetCount int
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it
+ // will have reported any
+ // errors on its side.
+ if bytes.Compare(receivedBody, body) != 0 {
+ return fmt.Errorf("body: %v; expected %v", receivedBody, body)
+ }
+ if resetCount != 1 {
+ return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
+ }
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame, *SettingsFrame:
+ case *HeadersFrame:
+ case *DataFrame:
+ receivedBody = append(receivedBody, f.Data()...)
+ case *RSTStreamFrame:
+ resetCount++
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+ ct.run()
+}
+
+func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
+func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) }