http2: add Transport.WriteByteTimeout
Add a Transport-level knob to set a timeout for writes to net.Conns.
If a write exceeds the timeout without making any progress (at least
one byte written), the connection is closed.
Fixes golang/go#48830.
Change-Id: If0f57996d11c92bced30e07d1e238cbf8994acb4
Reviewed-on: https://go-review.googlesource.com/c/net/+/354431
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 86b734a..7f376af 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -24,6 +24,7 @@
"net/http"
"net/http/httptrace"
"net/textproto"
+ "os"
"sort"
"strconv"
"strings"
@@ -130,6 +131,11 @@
// Defaults to 15s.
PingTimeout time.Duration
+ // WriteByteTimeout is the timeout after which the connection will be
+ // closed no data can be written to it. The timeout begins when data is
+ // available to write, and is extended whenever any bytes are written.
+ WriteByteTimeout time.Duration
+
// CountError, if non-nil, is called on HTTP/2 transport errors.
// It's intended to increment a metric for monitoring, such
// as an expvar or Prometheus metric.
@@ -393,17 +399,31 @@
}
type stickyErrWriter struct {
- w io.Writer
- err *error
+ conn net.Conn
+ timeout time.Duration
+ err *error
}
func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
- n, err = sew.w.Write(p)
- *sew.err = err
- return
+ for {
+ if sew.timeout != 0 {
+ sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
+ }
+ nn, err := sew.conn.Write(p[n:])
+ n += nn
+ if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
+ // Keep extending the deadline so long as we're making progress.
+ continue
+ }
+ if sew.timeout != 0 {
+ sew.conn.SetWriteDeadline(time.Time{})
+ }
+ *sew.err = err
+ return n, err
+ }
}
// noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -658,7 +678,11 @@
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
- cc.bw = bufio.NewWriter(stickyErrWriter{c, &cc.werr})
+ cc.bw = bufio.NewWriter(stickyErrWriter{
+ conn: c,
+ timeout: t.WriteByteTimeout,
+ err: &cc.werr,
+ })
cc.br = bufio.NewReader(c)
cc.fr = NewFramer(cc.bw, cc.br)
if t.CountError != nil {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index f0868d6..967edef 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5736,3 +5736,73 @@
res.Body.Close()
pw.Close()
}
+
+func TestTransportWriteByteTimeout(t *testing.T) {
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {},
+ optOnlyServer,
+ )
+ defer st.Close()
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ _, c := net.Pipe()
+ return c, nil
+ },
+ WriteByteTimeout: 1 * time.Millisecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+
+ _, err := c.Get(st.ts.URL)
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
+ }
+}
+
+type slowWriteConn struct {
+ net.Conn
+ hasWriteDeadline bool
+}
+
+func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
+ c.hasWriteDeadline = !t.IsZero()
+ return nil
+}
+
+func (c *slowWriteConn) Write(b []byte) (n int, err error) {
+ if c.hasWriteDeadline && len(b) > 1 {
+ n, err = c.Conn.Write(b[:1])
+ if err != nil {
+ return n, err
+ }
+ return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
+ }
+ return c.Conn.Write(b)
+}
+
+func TestTransportSlowWrites(t *testing.T) {
+ st := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {},
+ optOnlyServer,
+ )
+ defer st.Close()
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ cfg.InsecureSkipVerify = true
+ c, err := tls.Dial(network, addr, cfg)
+ return &slowWriteConn{Conn: c}, err
+ },
+ WriteByteTimeout: 1 * time.Millisecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+
+ const bodySize = 1 << 20
+ resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+}