http2: don't leak streams on broken body

Updates golang/go#27208

Change-Id: I5d9a643f33d27d33b24f670c98f5a51aa6000967
GitHub-Last-Rev: 3ac4a573b62846ef4944599085218e119819383c
GitHub-Pull-Request: golang/net#18
Reviewed-on: https://go-review.googlesource.com/c/132715
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 2c9fe88..3fe2918 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -1100,6 +1100,7 @@
 			default:
 			}
 			if err != nil {
+				cc.forgetStreamID(cs.ID)
 				return nil, cs.getStartedWrite(), err
 			}
 			bodyWritten = true
@@ -1221,6 +1222,7 @@
 			sawEOF = true
 			err = nil
 		} else if err != nil {
+			cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
 			return err
 		}
 
diff --git a/http2/transport_test.go b/http2/transport_test.go
index f61ff50..f6efa61 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -4180,3 +4180,99 @@
 		t.Fatalf("wrong kind %T; want *Transport", v.Interface())
 	}
 }
+
+type errReader struct {
+	body []byte
+	err  error
+}
+
+func (r *errReader) Read(p []byte) (int, error) {
+	if len(r.body) > 0 {
+		n := copy(p, r.body)
+		r.body = r.body[n:]
+		return n, nil
+	}
+	return 0, r.err
+}
+
+func testTransportBodyReadError(t *testing.T, body []byte) {
+	clientDone := make(chan struct{})
+	ct := newClientTester(t)
+	ct.client = func() error {
+		defer ct.cc.(*net.TCPConn).CloseWrite()
+		defer close(clientDone)
+
+		checkNoStreams := func() error {
+			cp, ok := ct.tr.connPool().(*clientConnPool)
+			if !ok {
+				return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool())
+			}
+			cp.mu.Lock()
+			defer cp.mu.Unlock()
+			conns, ok := cp.conns["dummy.tld:443"]
+			if !ok {
+				return fmt.Errorf("missing connection")
+			}
+			if len(conns) != 1 {
+				return fmt.Errorf("conn pool size: %v; expect 1", len(conns))
+			}
+			if activeStreams(conns[0]) != 0 {
+				return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0]))
+			}
+			return nil
+		}
+		bodyReadError := errors.New("body read error")
+		body := &errReader{body, bodyReadError}
+		req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
+		if err != nil {
+			return err
+		}
+		_, err = ct.tr.RoundTrip(req)
+		if err != bodyReadError {
+			return fmt.Errorf("err = %v; want %v", err, bodyReadError)
+		}
+		if err = checkNoStreams(); err != nil {
+			return err
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		var receivedBody []byte
+		var resetCount int
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-clientDone:
+					// If the client's done, it
+					// will have reported any
+					// errors on its side.
+					if bytes.Compare(receivedBody, body) != 0 {
+						return fmt.Errorf("body: %v; expected %v", receivedBody, body)
+					}
+					if resetCount != 1 {
+						return fmt.Errorf("stream reset count: %v; expected: 1", resetCount)
+					}
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+			case *HeadersFrame:
+			case *DataFrame:
+				receivedBody = append(receivedBody, f.Data()...)
+			case *RSTStreamFrame:
+				resetCount++
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+	}
+	ct.run()
+}
+
+func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
+func TestTransportBodyReadError_Some(t *testing.T)        { testTransportBodyReadError(t, []byte("123")) }