http2: rewrite inbound flow control tracking

Add a new inflow type for tracking inbound flow control.
An inflow tracks both the window sent to the peer, and the
window we are willing to send. Updates are accumulated and
sent in a batch when the unsent window update is large
enough.

This change makes both the client and server use the same
algorithm to decide when to send window updates. This should
slightly reduce the rate of updates sent by the client, and
significantly reduce the rate sent by the server.

Fix a client flow control tracking bug: When processing data
for a canceled stream, the record of flow control consumed
by the peer was not updated to account for the discard
stream.

Fixes golang/go#28732
Fixes golang/go#56558

Change-Id: Id119d17b84b46f3dc2719f28a86758d9a10085d9
Reviewed-on: https://go-review.googlesource.com/c/net/+/448155
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
diff --git a/http2/flow.go b/http2/flow.go
index b51f0e0..750ac52 100644
--- a/http2/flow.go
+++ b/http2/flow.go
@@ -6,23 +6,91 @@
 
 package http2
 
-// flow is the flow control window's size.
-type flow struct {
+// inflowMinRefresh is the minimum number of bytes we'll send for a
+// flow control window update.
+const inflowMinRefresh = 4 << 10
+
+// inflow accounts for an inbound flow control window.
+// It tracks both the latest window sent to the peer (used for enforcement)
+// and the accumulated unsent window.
+type inflow struct {
+	avail  int32
+	unsent int32
+}
+
+// set sets the initial window.
+func (f *inflow) init(n int32) {
+	f.avail = n
+}
+
+// add adds n bytes to the window, with a maximum window size of max,
+// indicating that the peer can now send us more data.
+// For example, the user read from a {Request,Response} body and consumed
+// some of the buffered data, so the peer can now send more.
+// It returns the number of bytes to send in a WINDOW_UPDATE frame to the peer.
+// Window updates are accumulated and sent when the unsent capacity
+// is at least inflowMinRefresh or will at least double the peer's available window.
+func (f *inflow) add(n int) (connAdd int32) {
+	if n < 0 {
+		panic("negative update")
+	}
+	unsent := int64(f.unsent) + int64(n)
+	// "A sender MUST NOT allow a flow-control window to exceed 2^31-1 octets."
+	// RFC 7540 Section 6.9.1.
+	const maxWindow = 1<<31 - 1
+	if unsent+int64(f.avail) > maxWindow {
+		panic("flow control update exceeds maximum window size")
+	}
+	f.unsent = int32(unsent)
+	if f.unsent < inflowMinRefresh && f.unsent < f.avail {
+		// If there aren't at least inflowMinRefresh bytes of window to send,
+		// and this update won't at least double the window, buffer the update for later.
+		return 0
+	}
+	f.avail += f.unsent
+	f.unsent = 0
+	return int32(unsent)
+}
+
+// take attempts to take n bytes from the peer's flow control window.
+// It reports whether the window has available capacity.
+func (f *inflow) take(n uint32) bool {
+	if n > uint32(f.avail) {
+		return false
+	}
+	f.avail -= int32(n)
+	return true
+}
+
+// takeInflows attempts to take n bytes from two inflows,
+// typically connection-level and stream-level flows.
+// It reports whether both windows have available capacity.
+func takeInflows(f1, f2 *inflow, n uint32) bool {
+	if n > uint32(f1.avail) || n > uint32(f2.avail) {
+		return false
+	}
+	f1.avail -= int32(n)
+	f2.avail -= int32(n)
+	return true
+}
+
+// outflow is the outbound flow control window's size.
+type outflow struct {
 	_ incomparable
 
 	// n is the number of DATA bytes we're allowed to send.
-	// A flow is kept both on a conn and a per-stream.
+	// An outflow is kept both on a conn and a per-stream.
 	n int32
 
-	// conn points to the shared connection-level flow that is
-	// shared by all streams on that conn. It is nil for the flow
+	// conn points to the shared connection-level outflow that is
+	// shared by all streams on that conn. It is nil for the outflow
 	// that's on the conn directly.
-	conn *flow
+	conn *outflow
 }
 
-func (f *flow) setConnFlow(cf *flow) { f.conn = cf }
+func (f *outflow) setConnFlow(cf *outflow) { f.conn = cf }
 
-func (f *flow) available() int32 {
+func (f *outflow) available() int32 {
 	n := f.n
 	if f.conn != nil && f.conn.n < n {
 		n = f.conn.n
@@ -30,7 +98,7 @@
 	return n
 }
 
-func (f *flow) take(n int32) {
+func (f *outflow) take(n int32) {
 	if n > f.available() {
 		panic("internal error: took too much")
 	}
@@ -42,7 +110,7 @@
 
 // add adds n bytes (positive or negative) to the flow control window.
 // It returns false if the sum would exceed 2^31-1.
-func (f *flow) add(n int32) bool {
+func (f *outflow) add(n int32) bool {
 	sum := f.n + n
 	if (sum > n) == (f.n > 0) {
 		f.n = sum
diff --git a/http2/flow_test.go b/http2/flow_test.go
index 7ae82c7..cae4f38 100644
--- a/http2/flow_test.go
+++ b/http2/flow_test.go
@@ -6,9 +6,61 @@
 
 import "testing"
 
-func TestFlow(t *testing.T) {
-	var st flow
-	var conn flow
+func TestInFlowTake(t *testing.T) {
+	var f inflow
+	f.init(100)
+	if !f.take(40) {
+		t.Fatalf("f.take(40) from 100: got false, want true")
+	}
+	if !f.take(40) {
+		t.Fatalf("f.take(40) from 60: got false, want true")
+	}
+	if f.take(40) {
+		t.Fatalf("f.take(40) from 20: got true, want false")
+	}
+	if !f.take(20) {
+		t.Fatalf("f.take(20) from 20: got false, want true")
+	}
+}
+
+func TestInflowAddSmall(t *testing.T) {
+	var f inflow
+	f.init(0)
+	// Adding even a small amount when there is no flow causes an immediate send.
+	if got, want := f.add(1), int32(1); got != want {
+		t.Fatalf("f.add(1) to 1 = %v, want %v", got, want)
+	}
+}
+
+func TestInflowAdd(t *testing.T) {
+	var f inflow
+	f.init(10 * inflowMinRefresh)
+	if got, want := f.add(inflowMinRefresh-1), int32(0); got != want {
+		t.Fatalf("f.add(minRefresh - 1) = %v, want %v", got, want)
+	}
+	if got, want := f.add(1), int32(inflowMinRefresh); got != want {
+		t.Fatalf("f.add(minRefresh) = %v, want %v", got, want)
+	}
+}
+
+func TestTakeInflows(t *testing.T) {
+	var a, b inflow
+	a.init(10)
+	b.init(20)
+	if !takeInflows(&a, &b, 5) {
+		t.Fatalf("takeInflows(a, b, 5) from 10, 20: got false, want true")
+	}
+	if takeInflows(&a, &b, 6) {
+		t.Fatalf("takeInflows(a, b, 6) from 5, 15: got true, want false")
+	}
+	if !takeInflows(&a, &b, 5) {
+		t.Fatalf("takeInflows(a, b, 5) from 5, 15: got false, want true")
+	}
+}
+
+func TestOutFlow(t *testing.T) {
+	var st outflow
+	var conn outflow
 	st.add(3)
 	conn.add(2)
 
@@ -29,8 +81,8 @@
 	}
 }
 
-func TestFlowAdd(t *testing.T) {
-	var f flow
+func TestOutFlowAdd(t *testing.T) {
+	var f outflow
 	if !f.add(1) {
 		t.Fatal("failed to add 1")
 	}
@@ -51,8 +103,8 @@
 	}
 }
 
-func TestFlowAddOverflow(t *testing.T) {
-	var f flow
+func TestOutFlowAddOverflow(t *testing.T) {
+	var f outflow
 	if !f.add(0) {
 		t.Fatal("failed to add 0")
 	}
diff --git a/http2/server.go b/http2/server.go
index 4eb7617..b624dc0 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -448,7 +448,7 @@
 	// configured value for inflow, that will be updated when we send a
 	// WINDOW_UPDATE shortly after sending SETTINGS.
 	sc.flow.add(initialWindowSize)
-	sc.inflow.add(initialWindowSize)
+	sc.inflow.init(initialWindowSize)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
 	sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize())
 
@@ -563,8 +563,8 @@
 	wroteFrameCh     chan frameWriteResult  // from writeFrameAsync -> serve, tickles more frame writes
 	bodyReadCh       chan bodyReadMsg       // from handlers -> serve
 	serveMsgCh       chan interface{}       // misc messages & code to send to / run on the serve loop
-	flow             flow                   // conn-wide (not stream-specific) outbound flow control
-	inflow           flow                   // conn-wide inbound flow control
+	flow             outflow                // conn-wide (not stream-specific) outbound flow control
+	inflow           inflow                 // conn-wide inbound flow control
 	tlsState         *tls.ConnectionState   // shared by all handlers, like net/http
 	remoteAddrStr    string
 	writeSched       WriteScheduler
@@ -641,10 +641,10 @@
 	cancelCtx func()
 
 	// owned by serverConn's serve loop:
-	bodyBytes        int64 // body bytes seen so far
-	declBodyBytes    int64 // or -1 if undeclared
-	flow             flow  // limits writing from Handler to client
-	inflow           flow  // what the client is allowed to POST/etc to us
+	bodyBytes        int64   // body bytes seen so far
+	declBodyBytes    int64   // or -1 if undeclared
+	flow             outflow // limits writing from Handler to client
+	inflow           inflow  // what the client is allowed to POST/etc to us
 	state            streamState
 	resetQueued      bool        // RST_STREAM queued for write; set by sc.resetStream
 	gotTrailerHeader bool        // HEADER frame for trailers was seen
@@ -1503,7 +1503,7 @@
 	if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {
 
 		if f, ok := f.(*DataFrame); ok {
-			if sc.inflow.available() < int32(f.Length) {
+			if !sc.inflow.take(f.Length) {
 				return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl))
 			}
 			sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
@@ -1775,14 +1775,9 @@
 		// But still enforce their connection-level flow control,
 		// and return any flow control bytes since we're not going
 		// to consume them.
-		if sc.inflow.available() < int32(f.Length) {
+		if !sc.inflow.take(f.Length) {
 			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
 		}
-		// Deduct the flow control from inflow, since we're
-		// going to immediately add it back in
-		// sendWindowUpdate, which also schedules sending the
-		// frames.
-		sc.inflow.take(int32(f.Length))
 		sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
 
 		if st != nil && st.resetQueued {
@@ -1797,10 +1792,9 @@
 
 	// Sender sending more than they'd declared?
 	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
-		if sc.inflow.available() < int32(f.Length) {
+		if !sc.inflow.take(f.Length) {
 			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
 		}
-		sc.inflow.take(int32(f.Length))
 		sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
 
 		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
@@ -1811,10 +1805,9 @@
 	}
 	if f.Length > 0 {
 		// Check whether the client has flow control quota.
-		if st.inflow.available() < int32(f.Length) {
+		if !takeInflows(&sc.inflow, &st.inflow, f.Length) {
 			return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl))
 		}
-		st.inflow.take(int32(f.Length))
 
 		if len(data) > 0 {
 			wrote, err := st.body.Write(data)
@@ -1830,10 +1823,12 @@
 
 		// Return any padded flow control now, since we won't
 		// refund it later on body reads.
-		if pad := int32(f.Length) - int32(len(data)); pad > 0 {
-			sc.sendWindowUpdate32(nil, pad)
-			sc.sendWindowUpdate32(st, pad)
-		}
+		// Call sendWindowUpdate even if there is no padding,
+		// to return buffered flow control credit if the sent
+		// window has shrunk.
+		pad := int32(f.Length) - int32(len(data))
+		sc.sendWindowUpdate32(nil, pad)
+		sc.sendWindowUpdate32(st, pad)
 	}
 	if f.StreamEnded() {
 		st.endStream()
@@ -2105,8 +2100,7 @@
 	st.cw.Init()
 	st.flow.conn = &sc.flow // link to conn-level counter
 	st.flow.add(sc.initialStreamSendWindowSize)
-	st.inflow.conn = &sc.inflow // link to conn-level counter
-	st.inflow.add(sc.srv.initialStreamRecvWindowSize())
+	st.inflow.init(sc.srv.initialStreamRecvWindowSize())
 	if sc.hs.WriteTimeout != 0 {
 		st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
 	}
@@ -2388,47 +2382,28 @@
 }
 
 // st may be nil for conn-level
-func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
-	sc.serveG.check()
-	// "The legal range for the increment to the flow control
-	// window is 1 to 2^31-1 (2,147,483,647) octets."
-	// A Go Read call on 64-bit machines could in theory read
-	// a larger Read than this. Very unlikely, but we handle it here
-	// rather than elsewhere for now.
-	const maxUint31 = 1<<31 - 1
-	for n > maxUint31 {
-		sc.sendWindowUpdate32(st, maxUint31)
-		n -= maxUint31
-	}
-	sc.sendWindowUpdate32(st, int32(n))
+func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) {
+	sc.sendWindowUpdate(st, int(n))
 }
 
 // st may be nil for conn-level
-func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) {
+func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
 	sc.serveG.check()
-	if n == 0 {
+	var streamID uint32
+	var send int32
+	if st == nil {
+		send = sc.inflow.add(n)
+	} else {
+		streamID = st.id
+		send = st.inflow.add(n)
+	}
+	if send == 0 {
 		return
 	}
-	if n < 0 {
-		panic("negative update")
-	}
-	var streamID uint32
-	if st != nil {
-		streamID = st.id
-	}
 	sc.writeFrame(FrameWriteRequest{
-		write:  writeWindowUpdate{streamID: streamID, n: uint32(n)},
+		write:  writeWindowUpdate{streamID: streamID, n: uint32(send)},
 		stream: st,
 	})
-	var ok bool
-	if st == nil {
-		ok = sc.inflow.add(n)
-	} else {
-		ok = st.inflow.add(n)
-	}
-	if !ok {
-		panic("internal error; sent too many window updates without decrements?")
-	}
 }
 
 // requestBody is the Handler's Request.Body type.
diff --git a/http2/server_test.go b/http2/server_test.go
index 815efe1..178c28b 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -482,6 +482,22 @@
 	}
 }
 
+// writeReadPing sends a PING and immediately reads the PING ACK.
+// It will fail if any other unread data was pending on the connection.
+func (st *serverTester) writeReadPing() {
+	data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+	if err := st.fr.WritePing(false, data); err != nil {
+		st.t.Fatalf("Error writing PING: %v", err)
+	}
+	p := st.wantPing()
+	if p.Flags&FlagPingAck == 0 {
+		st.t.Fatalf("got a PING, want a PING ACK")
+	}
+	if p.Data != data {
+		st.t.Fatalf("got PING data = %x, want %x", p.Data, data)
+	}
+}
+
 func (st *serverTester) readFrame() (Frame, error) {
 	return st.fr.ReadFrame()
 }
@@ -592,6 +608,28 @@
 	}
 }
 
+func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) {
+	var initial int32
+	if streamID == 0 {
+		initial = st.sc.srv.initialConnRecvWindowSize()
+	} else {
+		initial = st.sc.srv.initialStreamRecvWindowSize()
+	}
+	donec := make(chan struct{})
+	st.sc.sendServeMsg(func(sc *serverConn) {
+		defer close(donec)
+		var avail int32
+		if streamID == 0 {
+			avail = sc.inflow.avail + sc.inflow.unsent
+		} else {
+		}
+		if got, want := initial-avail, consumed; got != want {
+			st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want)
+		}
+	})
+	<-donec
+}
+
 func (st *serverTester) wantSettingsAck() {
 	f, err := st.readFrame()
 	if err != nil {
@@ -811,7 +849,8 @@
 			st.writeData(1, true, []byte("12345"))
 			// Return flow control bytes back, since the data handler closed
 			// the stream.
-			st.wantWindowUpdate(0, 5)
+			st.wantRSTStream(1, ErrCodeProtocol)
+			st.wantFlowControlConsumed(0, 0)
 		})
 }
 
@@ -1238,69 +1277,89 @@
 }
 
 func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
+	// Need to set this to at least twice the initial window size,
+	// or st.greet gets stuck waiting for a WINDOW_UPDATE.
+	//
+	// This also needs to be less than MAX_FRAME_SIZE.
+	const windowSize = 65535 * 2
 	puppet := newHandlerPuppet()
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 		puppet.act(w, r)
+	}, func(s *Server) {
+		s.MaxUploadBufferPerConnection = windowSize
+		s.MaxUploadBufferPerStream = windowSize
 	})
 	defer st.Close()
 	defer puppet.done()
 
 	st.greet()
-
 	st.writeHeaders(HeadersFrameParam{
 		StreamID:      1, // clients send odd numbers
 		BlockFragment: st.encodeHeader(":method", "POST"),
 		EndStream:     false, // data coming
 		EndHeaders:    true,
 	})
-	st.writeData(1, false, []byte("abcdef"))
-	puppet.do(readBodyHandler(t, "abc"))
-	st.wantWindowUpdate(0, 3)
-	st.wantWindowUpdate(1, 3)
+	st.writeReadPing()
 
-	puppet.do(readBodyHandler(t, "def"))
-	st.wantWindowUpdate(0, 3)
-	st.wantWindowUpdate(1, 3)
+	// Write less than half the max window of data and consume it.
+	// The server doesn't return flow control yet, buffering the 1024 bytes to
+	// combine with a future update.
+	data := make([]byte, windowSize)
+	st.writeData(1, false, data[:1024])
+	puppet.do(readBodyHandler(t, string(data[:1024])))
+	st.writeReadPing()
 
-	st.writeData(1, true, []byte("ghijkl")) // END_STREAM here
-	puppet.do(readBodyHandler(t, "ghi"))
-	puppet.do(readBodyHandler(t, "jkl"))
-	st.wantWindowUpdate(0, 3)
-	st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM
+	// Write up to the window limit.
+	// The server returns the buffered credit.
+	st.writeData(1, false, data[1024:])
+	st.wantWindowUpdate(0, 1024)
+	st.wantWindowUpdate(1, 1024)
+	st.writeReadPing()
+
+	// The handler consumes the data and the server returns credit.
+	puppet.do(readBodyHandler(t, string(data[1024:])))
+	st.wantWindowUpdate(0, windowSize-1024)
+	st.wantWindowUpdate(1, windowSize-1024)
+	st.writeReadPing()
 }
 
 // the version of the TestServer_Handler_Sends_WindowUpdate with padding.
 // See golang.org/issue/16556
 func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
+	const windowSize = 65535 * 2
 	puppet := newHandlerPuppet()
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 		puppet.act(w, r)
+	}, func(s *Server) {
+		s.MaxUploadBufferPerConnection = windowSize
+		s.MaxUploadBufferPerStream = windowSize
 	})
 	defer st.Close()
 	defer puppet.done()
 
 	st.greet()
-
 	st.writeHeaders(HeadersFrameParam{
 		StreamID:      1,
 		BlockFragment: st.encodeHeader(":method", "POST"),
 		EndStream:     false,
 		EndHeaders:    true,
 	})
-	st.writeDataPadded(1, false, []byte("abcdef"), []byte{0, 0, 0, 0})
+	st.writeReadPing()
 
-	// Expect to immediately get our 5 bytes of padding back for
-	// both the connection and stream (4 bytes of padding + 1 byte of length)
-	st.wantWindowUpdate(0, 5)
-	st.wantWindowUpdate(1, 5)
+	// Write half a window of data, with some padding.
+	// The server doesn't return the padding yet, buffering the 5 bytes to combine
+	// with a future update.
+	data := make([]byte, windowSize/2)
+	pad := make([]byte, 4)
+	st.writeDataPadded(1, false, data, pad)
+	st.writeReadPing()
 
-	puppet.do(readBodyHandler(t, "abc"))
-	st.wantWindowUpdate(0, 3)
-	st.wantWindowUpdate(1, 3)
-
-	puppet.do(readBodyHandler(t, "def"))
-	st.wantWindowUpdate(0, 3)
-	st.wantWindowUpdate(1, 3)
+	// The handler consumes the body.
+	// The server returns flow control for the body and padding
+	// (4 bytes of padding + 1 byte of length).
+	puppet.do(readBodyHandler(t, string(data)))
+	st.wantWindowUpdate(0, uint32(len(data)+1+len(pad)))
+	st.wantWindowUpdate(1, uint32(len(data)+1+len(pad)))
 }
 
 func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
@@ -2296,8 +2355,6 @@
 		// gigantic and/or sensitive "foo" payload now.
 		st.writeData(1, true, []byte(msg))
 
-		st.wantWindowUpdate(0, uint32(len(msg)))
-
 		hf = st.wantHeaders()
 		if hf.StreamEnded() {
 			t.Fatal("expected data to follow")
@@ -2485,15 +2542,16 @@
 		// it did before.
 		st.writeData(1, true, []byte("foo"))
 
-		// Get our flow control bytes back, since the handler didn't get them.
-		st.wantWindowUpdate(0, uint32(len("foo")))
-
 		// Sent after a peer sends data anyway (admittedly the
 		// previous RST_STREAM might've still been in-flight),
 		// but they'll get the more friendly 'cancel' code
 		// first.
 		st.wantRSTStream(1, ErrCodeStreamClosed)
 
+		// We should have our flow control bytes back,
+		// since the handler didn't get them.
+		st.wantFlowControlConsumed(0, 0)
+
 		// Set up a bunch of machinery to record the panic we saw
 		// previously.
 		var (
@@ -3967,8 +4025,8 @@
 			EndHeaders: true,
 		})
 		st.writeData(1, true, []byte("12345"))
-		st.wantWindowUpdate(0, 5)
 		st.wantRSTStream(1, ErrCodeProtocol)
+		st.wantFlowControlConsumed(0, 0)
 	})
 }
 
@@ -4258,7 +4316,8 @@
 }
 
 func TestServerWindowUpdateOnBodyClose(t *testing.T) {
-	const content = "12345678"
+	const windowSize = 65535 * 2
+	content := make([]byte, windowSize)
 	blockCh := make(chan bool)
 	errc := make(chan error, 1)
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
@@ -4275,6 +4334,9 @@
 		blockCh <- true
 		<-blockCh
 		errc <- nil
+	}, func(s *Server) {
+		s.MaxUploadBufferPerConnection = windowSize
+		s.MaxUploadBufferPerStream = windowSize
 	})
 	defer st.Close()
 
@@ -4288,13 +4350,13 @@
 		EndStream:  false, // to say DATA frames are coming
 		EndHeaders: true,
 	})
-	st.writeData(1, false, []byte(content[:5]))
+	st.writeData(1, false, content[:windowSize/2])
 	<-blockCh
 	st.stream(1).body.CloseWithError(io.EOF)
-	st.writeData(1, false, []byte(content[5:]))
 	blockCh <- true
 
-	increments := len(content)
+	// Wait for flow control credit for the portion of the request written so far.
+	increments := windowSize / 2
 	for {
 		f, err := st.readFrame()
 		if err == io.EOF {
@@ -4311,6 +4373,10 @@
 		}
 	}
 
+	// Writing data after the stream is reset immediately returns flow control credit.
+	st.writeData(1, false, content[windowSize/2:])
+	st.wantWindowUpdate(0, windowSize/2)
+
 	if err := <-errc; err != nil {
 		t.Error(err)
 	}
@@ -4465,11 +4531,7 @@
 		EndHeaders: true,
 	})
 	st.writeData(1, false, []byte(content[:5]))
-
-	_, err := st.readFrame()
-	if err != nil {
-		st.t.Fatal(err)
-	}
+	st.writeReadPing()
 
 	// Send a GOAWAY with ErrCodeNo, followed by a bogus window update.
 	// The server should close the connection.
diff --git a/http2/transport.go b/http2/transport.go
index 30f706e..b43ec10 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -47,10 +47,6 @@
 	// we buffer per stream.
 	transportDefaultStreamFlow = 4 << 20
 
-	// transportDefaultStreamMinRefresh is the minimum number of bytes we'll send
-	// a stream-level WINDOW_UPDATE for at a time.
-	transportDefaultStreamMinRefresh = 4 << 10
-
 	defaultUserAgent = "Go-http-client/2.0"
 
 	// initialMaxConcurrentStreams is a connections maxConcurrentStreams until
@@ -310,8 +306,8 @@
 
 	mu              sync.Mutex // guards following
 	cond            *sync.Cond // hold mu; broadcast on flow/closed changes
-	flow            flow       // our conn-level flow control quota (cs.flow is per stream)
-	inflow          flow       // peer's conn-level flow control
+	flow            outflow    // our conn-level flow control quota (cs.outflow is per stream)
+	inflow          inflow     // peer's conn-level flow control
 	doNotReuse      bool       // whether conn is marked to not be reused for any future requests
 	closing         bool
 	closed          bool
@@ -376,10 +372,10 @@
 	respHeaderRecv chan struct{}  // closed when headers are received
 	res            *http.Response // set if respHeaderRecv is closed
 
-	flow        flow  // guarded by cc.mu
-	inflow      flow  // guarded by cc.mu
-	bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
-	readErr     error // sticky read error; owned by transportResponseBody.Read
+	flow        outflow // guarded by cc.mu
+	inflow      inflow  // guarded by cc.mu
+	bytesRemain int64   // -1 means unknown; owned by transportResponseBody.Read
+	readErr     error   // sticky read error; owned by transportResponseBody.Read
 
 	reqBody              io.ReadCloser
 	reqBodyContentLength int64         // -1 means unknown
@@ -811,7 +807,7 @@
 	cc.bw.Write(clientPreface)
 	cc.fr.WriteSettings(initialSettings...)
 	cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow)
-	cc.inflow.add(transportDefaultConnFlow + initialWindowSize)
+	cc.inflow.init(transportDefaultConnFlow + initialWindowSize)
 	cc.bw.Flush()
 	if cc.werr != nil {
 		cc.Close()
@@ -2073,8 +2069,7 @@
 func (cc *ClientConn) addStreamLocked(cs *clientStream) {
 	cs.flow.add(int32(cc.initialWindowSize))
 	cs.flow.setConnFlow(&cc.flow)
-	cs.inflow.add(transportDefaultStreamFlow)
-	cs.inflow.setConnFlow(&cc.inflow)
+	cs.inflow.init(transportDefaultStreamFlow)
 	cs.ID = cc.nextStreamID
 	cc.nextStreamID += 2
 	cc.streams[cs.ID] = cs
@@ -2533,21 +2528,10 @@
 	}
 
 	cc.mu.Lock()
-	var connAdd, streamAdd int32
-	// Check the conn-level first, before the stream-level.
-	if v := cc.inflow.available(); v < transportDefaultConnFlow/2 {
-		connAdd = transportDefaultConnFlow - v
-		cc.inflow.add(connAdd)
-	}
+	connAdd := cc.inflow.add(n)
+	var streamAdd int32
 	if err == nil { // No need to refresh if the stream is over or failed.
-		// Consider any buffered body data (read from the conn but not
-		// consumed by the client) when computing flow control for this
-		// stream.
-		v := int(cs.inflow.available()) + cs.bufPipe.Len()
-		if v < transportDefaultStreamFlow-transportDefaultStreamMinRefresh {
-			streamAdd = int32(transportDefaultStreamFlow - v)
-			cs.inflow.add(streamAdd)
-		}
+		streamAdd = cs.inflow.add(n)
 	}
 	cc.mu.Unlock()
 
@@ -2575,17 +2559,15 @@
 	if unread > 0 {
 		cc.mu.Lock()
 		// Return connection-level flow control.
-		if unread > 0 {
-			cc.inflow.add(int32(unread))
-		}
+		connAdd := cc.inflow.add(unread)
 		cc.mu.Unlock()
 
 		// TODO(dneil): Acquiring this mutex can block indefinitely.
 		// Move flow control return to a goroutine?
 		cc.wmu.Lock()
 		// Return connection-level flow control.
-		if unread > 0 {
-			cc.fr.WriteWindowUpdate(0, uint32(unread))
+		if connAdd > 0 {
+			cc.fr.WriteWindowUpdate(0, uint32(connAdd))
 		}
 		cc.bw.Flush()
 		cc.wmu.Unlock()
@@ -2628,13 +2610,18 @@
 		// But at least return their flow control:
 		if f.Length > 0 {
 			cc.mu.Lock()
-			cc.inflow.add(int32(f.Length))
+			ok := cc.inflow.take(f.Length)
+			connAdd := cc.inflow.add(int(f.Length))
 			cc.mu.Unlock()
-
-			cc.wmu.Lock()
-			cc.fr.WriteWindowUpdate(0, uint32(f.Length))
-			cc.bw.Flush()
-			cc.wmu.Unlock()
+			if !ok {
+				return ConnectionError(ErrCodeFlowControl)
+			}
+			if connAdd > 0 {
+				cc.wmu.Lock()
+				cc.fr.WriteWindowUpdate(0, uint32(connAdd))
+				cc.bw.Flush()
+				cc.wmu.Unlock()
+			}
 		}
 		return nil
 	}
@@ -2665,9 +2652,7 @@
 		}
 		// Check connection-level flow control.
 		cc.mu.Lock()
-		if cs.inflow.available() >= int32(f.Length) {
-			cs.inflow.take(int32(f.Length))
-		} else {
+		if !takeInflows(&cc.inflow, &cs.inflow, f.Length) {
 			cc.mu.Unlock()
 			return ConnectionError(ErrCodeFlowControl)
 		}
@@ -2689,19 +2674,20 @@
 			}
 		}
 
-		if refund > 0 {
-			cc.inflow.add(int32(refund))
-			if !didReset {
-				cs.inflow.add(int32(refund))
-			}
+		sendConn := cc.inflow.add(refund)
+		var sendStream int32
+		if !didReset {
+			sendStream = cs.inflow.add(refund)
 		}
 		cc.mu.Unlock()
 
-		if refund > 0 {
+		if sendConn > 0 || sendStream > 0 {
 			cc.wmu.Lock()
-			cc.fr.WriteWindowUpdate(0, uint32(refund))
-			if !didReset {
-				cc.fr.WriteWindowUpdate(cs.ID, uint32(refund))
+			if sendConn > 0 {
+				cc.fr.WriteWindowUpdate(0, uint32(sendConn))
+			}
+			if sendStream > 0 {
+				cc.fr.WriteWindowUpdate(cs.ID, uint32(sendStream))
 			}
 			cc.bw.Flush()
 			cc.wmu.Unlock()
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 00776ad..5adef42 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -835,6 +835,55 @@
 	}
 }
 
+// writeReadPing sends a PING and immediately reads the PING ACK.
+// It will fail if any other unread data was pending on the connection,
+// aside from SETTINGS frames.
+func (ct *clientTester) writeReadPing() error {
+	data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+	if err := ct.fr.WritePing(false, data); err != nil {
+		return fmt.Errorf("Error writing PING: %v", err)
+	}
+	f, err := ct.readNonSettingsFrame()
+	if err != nil {
+		return err
+	}
+	p, ok := f.(*PingFrame)
+	if !ok {
+		return fmt.Errorf("got a %v, want a PING ACK", f)
+	}
+	if p.Flags&FlagPingAck == 0 {
+		return fmt.Errorf("got a PING, want a PING ACK")
+	}
+	if p.Data != data {
+		return fmt.Errorf("got PING data = %x, want %x", p.Data, data)
+	}
+	return nil
+}
+
+func (ct *clientTester) inflowWindow(streamID uint32) int32 {
+	pool := ct.tr.connPoolOrDef.(*clientConnPool)
+	pool.mu.Lock()
+	defer pool.mu.Unlock()
+	if n := len(pool.keys); n != 1 {
+		ct.t.Errorf("clientConnPool contains %v keys, expected 1", n)
+		return -1
+	}
+	for cc := range pool.keys {
+		cc.mu.Lock()
+		defer cc.mu.Unlock()
+		if streamID == 0 {
+			return cc.inflow.avail + cc.inflow.unsent
+		}
+		cs := cc.streams[streamID]
+		if cs == nil {
+			ct.t.Errorf("no stream with id %v", streamID)
+			return -1
+		}
+		return cs.inflow.avail + cs.inflow.unsent
+	}
+	return -1
+}
+
 func (ct *clientTester) cleanup() {
 	ct.tr.CloseIdleConnections()
 
@@ -2905,22 +2954,17 @@
 func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
 	ct := newClientTester(t)
 
-	clientClosed := make(chan struct{})
-	serverWroteFirstByte := make(chan struct{})
-
 	ct.client = func() error {
 		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
 		res, err := ct.tr.RoundTrip(req)
 		if err != nil {
 			return err
 		}
-		<-serverWroteFirstByte
 
 		if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
 			return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
 		}
 		res.Body.Close() // leaving 4999 bytes unread
-		close(clientClosed)
 
 		return nil
 	}
@@ -2955,6 +2999,7 @@
 			EndStream:     false,
 			BlockFragment: buf.Bytes(),
 		})
+		initialInflow := ct.inflowWindow(0)
 
 		// Two cases:
 		// - Send one DATA frame with 5000 bytes.
@@ -2963,50 +3008,63 @@
 		// In both cases, the client should consume one byte of data,
 		// refund that byte, then refund the following 4999 bytes.
 		//
-		// In the second case, the server waits for the client connection to
-		// close before seconding the second DATA frame. This tests the case
+		// In the second case, the server waits for the client to reset the
+		// stream before sending the second DATA frame. This tests the case
 		// where the client receives a DATA frame after it has reset the stream.
 		if oneDataFrame {
 			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
-			close(serverWroteFirstByte)
-			<-clientClosed
 		} else {
 			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
-			close(serverWroteFirstByte)
-			<-clientClosed
-			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
 		}
 
-		waitingFor := "RSTStreamFrame"
-		sawRST := false
-		sawWUF := false
-		for !sawRST && !sawWUF {
-			f, err := ct.fr.ReadFrame()
+		wantRST := true
+		wantWUF := true
+		if !oneDataFrame {
+			wantWUF = false // flow control update is small, and will not be sent
+		}
+		for wantRST || wantWUF {
+			f, err := ct.readNonSettingsFrame()
 			if err != nil {
-				return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err)
+				return err
 			}
 			switch f := f.(type) {
-			case *SettingsFrame:
 			case *RSTStreamFrame:
-				if sawRST {
-					return fmt.Errorf("saw second RSTStreamFrame: %v", summarizeFrame(f))
+				if !wantRST {
+					return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
 				}
 				if f.ErrCode != ErrCodeCancel {
 					return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
 				}
-				sawRST = true
+				wantRST = false
 			case *WindowUpdateFrame:
-				if sawWUF {
-					return fmt.Errorf("saw second WindowUpdateFrame: %v", summarizeFrame(f))
+				if !wantWUF {
+					return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
 				}
-				if f.Increment != 4999 {
+				if f.Increment != 5000 {
 					return fmt.Errorf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
 				}
-				sawWUF = true
+				wantWUF = false
 			default:
 				return fmt.Errorf("Unexpected frame: %v", summarizeFrame(f))
 			}
 		}
+		if !oneDataFrame {
+			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
+			f, err := ct.readNonSettingsFrame()
+			if err != nil {
+				return err
+			}
+			wuf, ok := f.(*WindowUpdateFrame)
+			if !ok || wuf.Increment != 5000 {
+				return fmt.Errorf("want WindowUpdateFrame for 5000 bytes; got %v", summarizeFrame(f))
+			}
+		}
+		if err := ct.writeReadPing(); err != nil {
+			return err
+		}
+		if got, want := ct.inflowWindow(0), initialInflow; got != want {
+			return fmt.Errorf("connection flow tokens = %v, want %v", got, want)
+		}
 		return nil
 	}
 	ct.run()
@@ -3133,6 +3191,8 @@
 			break
 		}
 
+		initialConnWindow := ct.inflowWindow(0)
+
 		var buf bytes.Buffer
 		enc := hpack.NewEncoder(&buf)
 		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
@@ -3143,24 +3203,18 @@
 			EndStream:     false,
 			BlockFragment: buf.Bytes(),
 		})
+		initialStreamWindow := ct.inflowWindow(hf.StreamID)
 		pad := make([]byte, 5)
 		ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream
-
-		f, err := ct.readNonSettingsFrame()
-		if err != nil {
-			return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err)
+		if err := ct.writeReadPing(); err != nil {
+			return err
 		}
-		wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding
-		if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 {
-			return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
+		// Padding flow control should have been returned.
+		if got, want := ct.inflowWindow(0), initialConnWindow-5000; got != want {
+			t.Errorf("conn inflow window = %v, want %v", got, want)
 		}
-
-		f, err = ct.readNonSettingsFrame()
-		if err != nil {
-			return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err)
-		}
-		if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 {
-			return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
+		if got, want := ct.inflowWindow(hf.StreamID), initialStreamWindow-5000; got != want {
+			t.Errorf("stream inflow window = %v, want %v", got, want)
 		}
 		unblockClient <- true
 		return nil