Respect client's advertised flow control window size when writing from Handlers.
diff --git a/flow.go b/flow.go
index dbb211b..25b231f 100644
--- a/flow.go
+++ b/flow.go
@@ -33,10 +33,11 @@
return f.size
}
-// acquire decrements the flow control window by n bytes, blocking
-// until they're available in the window.
-// The return value is only interesting for tests.
-func (f *flow) acquire(n int32) (waited int) {
+// wait waits for between 1 and n bytes (inclusive) to be available
+// and returns the number of quota bytes decremented from the quota
+// and allowed to be written. The returned value will be 0 iff the
+// stream has been killed.
+func (f *flow) wait(n int32) (got int32) {
if n < 0 {
panic("negative acquire")
}
@@ -44,13 +45,16 @@
defer f.c.L.Unlock()
for {
if f.closed {
- return
+ return 0
}
- if f.size >= n {
- f.size -= n
- return
+ if f.size >= 1 {
+ got = f.size
+ if got > n {
+ got = n
+ }
+ f.size -= got
+ return got
}
- waited++
f.c.Wait()
}
}
diff --git a/flow_test.go b/flow_test.go
index 6d66861..284ccbd 100644
--- a/flow_test.go
+++ b/flow_test.go
@@ -15,12 +15,18 @@
if got, want := f.cur(), int32(10); got != want {
t.Fatalf("size = %d; want %d", got, want)
}
- if waits := f.acquire(1); waits != 0 {
- t.Errorf("waits = %d; want 0", waits)
+ if got, want := f.wait(1), int32(1); got != want {
+ t.Errorf("wait = %d; want %d", got, want)
}
if got, want := f.cur(), int32(9); got != want {
t.Fatalf("size = %d; want %d", got, want)
}
+ if got, want := f.wait(20), int32(9); got != want {
+ t.Errorf("wait = %d; want %d", got, want)
+ }
+ if got, want := f.cur(), int32(0); got != want {
+ t.Fatalf("size = %d; want %d", got, want)
+ }
// Wait for 10, which should block, so start a background goroutine
// to refill it.
@@ -28,8 +34,8 @@
time.Sleep(50 * time.Millisecond)
f.add(50)
}()
- if waits := f.acquire(10); waits != 1 {
- t.Errorf("waits for 50 = %d; want 0", waits)
+ if got, want := f.wait(1), int32(1); got != want {
+ t.Errorf("after block, got %d; want %d", got, want)
}
if got, want := f.cur(), int32(49); got != want {
@@ -69,13 +75,15 @@
time.Sleep(50 * time.Millisecond)
f.close()
}()
- donec := make(chan bool)
+ gotc := make(chan int32)
go func() {
- defer close(donec)
- f.acquire(10)
+ gotc <- f.wait(1)
}()
select {
- case <-donec:
+ case got := <-gotc:
+ if got != 0 {
+ t.Errorf("got %d; want 0", got)
+ }
case <-time.After(2 * time.Second):
t.Error("timeout")
}
diff --git a/server.go b/server.go
index a2de381..bba6557 100644
--- a/server.go
+++ b/server.go
@@ -35,6 +35,7 @@
var (
errClientDisconnected = errors.New("client disconnected")
errClosedBody = errors.New("body closed by handler")
+ errStreamBroken = errors.New("http2: stream broken")
)
var responseWriterStatePool = sync.Pool{
@@ -1516,6 +1517,11 @@
if len(chunk) > handlerChunkWriteSize {
chunk = chunk[:handlerChunkWriteSize]
}
+ allowedSize := rws.stream.flow.wait(int32(len(chunk)))
+ if allowedSize == 0 {
+ return n, errStreamBroken
+ }
+ chunk = chunk[:allowedSize]
p = p[len(chunk):]
isFinal := rws.handlerDone && len(p) == 0
err = rws.writeData(chunk, isFinal)
diff --git a/server_test.go b/server_test.go
index 9456693..fdc3c3f 100644
--- a/server_test.go
+++ b/server_test.go
@@ -1254,6 +1254,12 @@
return nil
}, func(st *serverTester) {
getSlash(st) // make the single request
+
+ // Give the handler quota to write:
+ if err := st.fr.WriteWindowUpdate(1, size); err != nil {
+ t.Fatal(err)
+ }
+
hf := st.wantHeaders()
if hf.StreamEnded() {
t.Fatal("unexpected END_STREAM flag")
@@ -1275,7 +1281,6 @@
df := st.wantData()
bytes += len(df.Data())
frames++
- // TODO: send WINDOW_UPDATE frames at the server to keep it from stalling
for _, b := range df.Data() {
if b != 'a' {
t.Fatal("non-'a' byte seen in DATA")
@@ -1294,6 +1299,61 @@
})
}
+// Test that the handler can't write more than the client allows
+func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
+ const size = 1 << 20
+ testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
+ w.(http.Flusher).Flush()
+ n, err := w.Write(bytes.Repeat([]byte("a"), size))
+ if err != nil {
+ return fmt.Errorf("Write error: %v", err)
+ }
+ if n != size {
+ return fmt.Errorf("wrong size %d from Write", n)
+ }
+ return nil
+ }, func(st *serverTester) {
+ // Set the window size to something explicit for this test.
+ // It's also how much initial data we expect.
+ const initWindowSize = 123
+ if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, initWindowSize}); err != nil {
+ t.Fatal(err)
+ }
+ st.wantSettingsAck()
+
+ getSlash(st) // make the single request
+
+ defer func() { st.fr.WriteRSTStream(1, ErrCodeCancel) }()
+
+ hf := st.wantHeaders()
+ if hf.StreamEnded() {
+ t.Fatal("unexpected END_STREAM flag")
+ }
+ if !hf.HeadersEnded() {
+ t.Fatal("want END_HEADERS flag")
+ }
+
+ df := st.wantData()
+ if got := len(df.Data()); got != initWindowSize {
+ t.Fatalf("Initial window size = %d but got DATA with %d bytes", initWindowSize, got)
+ }
+
+ for _, quota := range []int{1, 13, 127} {
+ if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
+ t.Fatal(err)
+ }
+ df := st.wantData()
+ if int(quota) != len(df.Data()) {
+ t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota)
+ }
+ }
+
+ if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
+ t.Fatal(err)
+ }
+ })
+}
+
func TestServer_Response_Automatic100Continue(t *testing.T) {
const msg = "foo"
const reply = "bar"