http2: add Transport.WriteByteTimeout

Add a Transport-level knob to set a timeout for writes to net.Conns.
If a write exceeds the timeout without making any progress (at least
one byte written), the connection is closed.

Fixes golang/go#48830.

Change-Id: If0f57996d11c92bced30e07d1e238cbf8994acb4
Reviewed-on: https://go-review.googlesource.com/c/net/+/354431
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 86b734a..7f376af 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -24,6 +24,7 @@
 	"net/http"
 	"net/http/httptrace"
 	"net/textproto"
+	"os"
 	"sort"
 	"strconv"
 	"strings"
@@ -130,6 +131,11 @@
 	// Defaults to 15s.
 	PingTimeout time.Duration
 
+	// WriteByteTimeout is the timeout after which the connection will be
+	// closed no data can be written to it. The timeout begins when data is
+	// available to write, and is extended whenever any bytes are written.
+	WriteByteTimeout time.Duration
+
 	// CountError, if non-nil, is called on HTTP/2 transport errors.
 	// It's intended to increment a metric for monitoring, such
 	// as an expvar or Prometheus metric.
@@ -393,17 +399,31 @@
 }
 
 type stickyErrWriter struct {
-	w   io.Writer
-	err *error
+	conn    net.Conn
+	timeout time.Duration
+	err     *error
 }
 
 func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
 	if *sew.err != nil {
 		return 0, *sew.err
 	}
-	n, err = sew.w.Write(p)
-	*sew.err = err
-	return
+	for {
+		if sew.timeout != 0 {
+			sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
+		}
+		nn, err := sew.conn.Write(p[n:])
+		n += nn
+		if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
+			// Keep extending the deadline so long as we're making progress.
+			continue
+		}
+		if sew.timeout != 0 {
+			sew.conn.SetWriteDeadline(time.Time{})
+		}
+		*sew.err = err
+		return n, err
+	}
 }
 
 // noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -658,7 +678,11 @@
 
 	// TODO: adjust this writer size to account for frame size +
 	// MTU + crypto/tls record padding.
-	cc.bw = bufio.NewWriter(stickyErrWriter{c, &cc.werr})
+	cc.bw = bufio.NewWriter(stickyErrWriter{
+		conn:    c,
+		timeout: t.WriteByteTimeout,
+		err:     &cc.werr,
+	})
 	cc.br = bufio.NewReader(c)
 	cc.fr = NewFramer(cc.bw, cc.br)
 	if t.CountError != nil {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index f0868d6..967edef 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5736,3 +5736,73 @@
 	res.Body.Close()
 	pw.Close()
 }
+
+func TestTransportWriteByteTimeout(t *testing.T) {
+	st := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {},
+		optOnlyServer,
+	)
+	defer st.Close()
+	tr := &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			_, c := net.Pipe()
+			return c, nil
+		},
+		WriteByteTimeout: 1 * time.Millisecond,
+	}
+	defer tr.CloseIdleConnections()
+	c := &http.Client{Transport: tr}
+
+	_, err := c.Get(st.ts.URL)
+	if !errors.Is(err, os.ErrDeadlineExceeded) {
+		t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
+	}
+}
+
+type slowWriteConn struct {
+	net.Conn
+	hasWriteDeadline bool
+}
+
+func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
+	c.hasWriteDeadline = !t.IsZero()
+	return nil
+}
+
+func (c *slowWriteConn) Write(b []byte) (n int, err error) {
+	if c.hasWriteDeadline && len(b) > 1 {
+		n, err = c.Conn.Write(b[:1])
+		if err != nil {
+			return n, err
+		}
+		return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
+	}
+	return c.Conn.Write(b)
+}
+
+func TestTransportSlowWrites(t *testing.T) {
+	st := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {},
+		optOnlyServer,
+	)
+	defer st.Close()
+	tr := &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			cfg.InsecureSkipVerify = true
+			c, err := tls.Dial(network, addr, cfg)
+			return &slowWriteConn{Conn: c}, err
+		},
+		WriteByteTimeout: 1 * time.Millisecond,
+	}
+	defer tr.CloseIdleConnections()
+	c := &http.Client{Transport: tr}
+
+	const bodySize = 1 << 20
+	resp, err := c.Post(st.ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
+	if err != nil {
+		t.Fatal(err)
+	}
+	resp.Body.Close()
+}