http2: add Server.WriteByteTimeout
Transports support a WriteByteTimeout option which sets the maximum
amount of time we can go without being able to write any bytes to
a connection. Add an equivalent option to Server for consistency.
Fixes golang/go#61777
Change-Id: Iaa8a69dfc403906eb224829320f901e5a6a5c429
Reviewed-on: https://go-review.googlesource.com/c/net/+/601496
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/connframes_test.go b/http2/connframes_test.go
index 7db8b74..1f7834c 100644
--- a/http2/connframes_test.go
+++ b/http2/connframes_test.go
@@ -6,7 +6,6 @@
import (
"bytes"
- "context"
"io"
"net/http"
"os"
@@ -295,7 +294,7 @@
if err == nil {
tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
}
- if err == context.DeadlineExceeded {
+ if err == os.ErrDeadlineExceeded {
tf.t.Fatalf("connection is not closed; want it to be")
}
}
@@ -306,7 +305,7 @@
if err == nil {
tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
}
- if err != context.DeadlineExceeded {
+ if err != os.ErrDeadlineExceeded {
tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
}
}
diff --git a/http2/http2.go b/http2/http2.go
index 003e649..7688c35 100644
--- a/http2/http2.go
+++ b/http2/http2.go
@@ -19,8 +19,9 @@
"bufio"
"context"
"crypto/tls"
+ "errors"
"fmt"
- "io"
+ "net"
"net/http"
"os"
"sort"
@@ -237,13 +238,19 @@
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type bufferedWriter struct {
- _ incomparable
- w io.Writer // immutable
- bw *bufio.Writer // non-nil when data is buffered
+ _ incomparable
+ group synctestGroupInterface // immutable
+ conn net.Conn // immutable
+ bw *bufio.Writer // non-nil when data is buffered
+ byteTimeout time.Duration // immutable, WriteByteTimeout
}
-func newBufferedWriter(w io.Writer) *bufferedWriter {
- return &bufferedWriter{w: w}
+func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
+ return &bufferedWriter{
+ group: group,
+ conn: conn,
+ byteTimeout: timeout,
+ }
}
// bufWriterPoolBufferSize is the size of bufio.Writer's
@@ -270,7 +277,7 @@
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer)
- bw.Reset(w.w)
+ bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw
}
return w.bw.Write(p)
@@ -288,6 +295,38 @@
return err
}
+type bufferedWriterTimeoutWriter bufferedWriter
+
+func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
+ return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
+}
+
+// writeWithByteTimeout writes to conn.
+// If more than timeout passes without any bytes being written to the connection,
+// the write fails.
+func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
+ if timeout <= 0 {
+ return conn.Write(p)
+ }
+ for {
+ var now time.Time
+ if group == nil {
+ now = time.Now()
+ } else {
+ now = group.Now()
+ }
+ conn.SetWriteDeadline(now.Add(timeout))
+ nn, err := conn.Write(p[n:])
+ n += nn
+ if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
+ // Either we finished the write, made no progress, or hit the deadline.
+ // Whichever it is, we're done now.
+ conn.SetWriteDeadline(time.Time{})
+ return n, err
+ }
+ }
+}
+
func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 {
panic("out of range")
diff --git a/http2/server.go b/http2/server.go
index 6c349f3..b16173c 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -127,6 +127,12 @@
// If zero or negative, there is no timeout.
IdleTimeout time.Duration
+ // WriteByteTimeout is the timeout after which a connection will be
+ // closed if no data can be written to it. The timeout begins when data is
+ // available to write, and is extended whenever any bytes are written.
+ // If zero or negative, there is no timeout.
+ WriteByteTimeout time.Duration
+
// MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1.
@@ -446,7 +452,7 @@
conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
- bw: newBufferedWriter(c),
+ bw: newBufferedWriter(s.group, c, s.WriteByteTimeout),
handler: opts.handler(),
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
@@ -1320,6 +1326,10 @@
sc.writingFrame = false
sc.writingFrameAsync = false
+ if res.err != nil {
+ sc.conn.Close()
+ }
+
wr := res.wr
if writeEndsStream(wr.write) {
diff --git a/http2/server_test.go b/http2/server_test.go
index 47c3c61..ab53c26 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -4674,3 +4674,35 @@
}
resp.Body.Close()
}
+
+func TestServerWriteByteTimeout(t *testing.T) {
+ const timeout = 1 * time.Second
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.Write(make([]byte, 100))
+ }, func(s *Server) {
+ s.WriteByteTimeout = timeout
+ })
+ st.greet()
+
+ st.cc.(*synctestNetConn).SetReadBufferSize(1) // write one byte at a time
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+
+ // Read a few bytes, staying just under WriteByteTimeout.
+ for i := 0; i < 10; i++ {
+ st.advance(timeout - 1)
+ if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
+ t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
+ }
+ }
+
+ // Wait for WriteByteTimeout.
+ // The connection should close.
+ st.advance(1 * time.Second) // timeout after writing one byte
+ st.advance(1 * time.Second) // timeout after failing to write any more bytes
+ st.wantClosed()
+}
diff --git a/http2/transport.go b/http2/transport.go
index 61f511f..49fc792 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -25,7 +25,6 @@
"net/http"
"net/http/httptrace"
"net/textproto"
- "os"
"sort"
"strconv"
"strings"
@@ -499,6 +498,7 @@
}
type stickyErrWriter struct {
+ group synctestGroupInterface
conn net.Conn
timeout time.Duration
err *error
@@ -508,22 +508,9 @@
if *sew.err != nil {
return 0, *sew.err
}
- for {
- if sew.timeout != 0 {
- sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
- }
- nn, err := sew.conn.Write(p[n:])
- n += nn
- if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
- // Keep extending the deadline so long as we're making progress.
- continue
- }
- if sew.timeout != 0 {
- sew.conn.SetWriteDeadline(time.Time{})
- }
- *sew.err = err
- return n, err
- }
+ n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
+ *sew.err = err
+ return n, err
}
// noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -792,10 +779,12 @@
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
}
+ var group synctestGroupInterface
if t.transportTestHooks != nil {
t.markNewGoroutine()
t.transportTestHooks.newclientconn(cc)
c = cc.tconn
+ group = t.group
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -807,6 +796,7 @@
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
cc.bw = bufio.NewWriter(stickyErrWriter{
+ group: group,
conn: c,
timeout: t.WriteByteTimeout,
err: &cc.werr,