http2: add Server.WriteByteTimeout

Transports support a WriteByteTimeout option which sets the maximum
amount of time we can go without being able to write any bytes to
a connection. Add an equivalent option to Server for consistency.

Fixes golang/go#61777

Change-Id: Iaa8a69dfc403906eb224829320f901e5a6a5c429
Reviewed-on: https://go-review.googlesource.com/c/net/+/601496
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/connframes_test.go b/http2/connframes_test.go
index 7db8b74..1f7834c 100644
--- a/http2/connframes_test.go
+++ b/http2/connframes_test.go
@@ -6,7 +6,6 @@
 
 import (
 	"bytes"
-	"context"
 	"io"
 	"net/http"
 	"os"
@@ -295,7 +294,7 @@
 	if err == nil {
 		tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
 	}
-	if err == context.DeadlineExceeded {
+	if err == os.ErrDeadlineExceeded {
 		tf.t.Fatalf("connection is not closed; want it to be")
 	}
 }
@@ -306,7 +305,7 @@
 	if err == nil {
 		tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
 	}
-	if err != context.DeadlineExceeded {
+	if err != os.ErrDeadlineExceeded {
 		tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
 	}
 }
diff --git a/http2/http2.go b/http2/http2.go
index 003e649..7688c35 100644
--- a/http2/http2.go
+++ b/http2/http2.go
@@ -19,8 +19,9 @@
 	"bufio"
 	"context"
 	"crypto/tls"
+	"errors"
 	"fmt"
-	"io"
+	"net"
 	"net/http"
 	"os"
 	"sort"
@@ -237,13 +238,19 @@
 // Its buffered writer is lazily allocated as needed, to minimize
 // idle memory usage with many connections.
 type bufferedWriter struct {
-	_  incomparable
-	w  io.Writer     // immutable
-	bw *bufio.Writer // non-nil when data is buffered
+	_           incomparable
+	group       synctestGroupInterface // immutable
+	conn        net.Conn               // immutable
+	bw          *bufio.Writer          // non-nil when data is buffered
+	byteTimeout time.Duration          // immutable, WriteByteTimeout
 }
 
-func newBufferedWriter(w io.Writer) *bufferedWriter {
-	return &bufferedWriter{w: w}
+func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
+	return &bufferedWriter{
+		group:       group,
+		conn:        conn,
+		byteTimeout: timeout,
+	}
 }
 
 // bufWriterPoolBufferSize is the size of bufio.Writer's
@@ -270,7 +277,7 @@
 func (w *bufferedWriter) Write(p []byte) (n int, err error) {
 	if w.bw == nil {
 		bw := bufWriterPool.Get().(*bufio.Writer)
-		bw.Reset(w.w)
+		bw.Reset((*bufferedWriterTimeoutWriter)(w))
 		w.bw = bw
 	}
 	return w.bw.Write(p)
@@ -288,6 +295,38 @@
 	return err
 }
 
+type bufferedWriterTimeoutWriter bufferedWriter
+
+func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
+	return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
+}
+
+// writeWithByteTimeout writes to conn.
+// If more than timeout passes without any bytes being written to the connection,
+// the write fails.
+func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
+	if timeout <= 0 {
+		return conn.Write(p)
+	}
+	for {
+		var now time.Time
+		if group == nil {
+			now = time.Now()
+		} else {
+			now = group.Now()
+		}
+		conn.SetWriteDeadline(now.Add(timeout))
+		nn, err := conn.Write(p[n:])
+		n += nn
+		if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
+			// Either we finished the write, made no progress, or hit the deadline.
+			// Whichever it is, we're done now.
+			conn.SetWriteDeadline(time.Time{})
+			return n, err
+		}
+	}
+}
+
 func mustUint31(v int32) uint32 {
 	if v < 0 || v > 2147483647 {
 		panic("out of range")
diff --git a/http2/server.go b/http2/server.go
index 6c349f3..b16173c 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -127,6 +127,12 @@
 	// If zero or negative, there is no timeout.
 	IdleTimeout time.Duration
 
+	// WriteByteTimeout is the timeout after which a connection will be
+	// closed if no data can be written to it. The timeout begins when data is
+	// available to write, and is extended whenever any bytes are written.
+	// If zero or negative, there is no timeout.
+	WriteByteTimeout time.Duration
+
 	// MaxUploadBufferPerConnection is the size of the initial flow
 	// control window for each connections. The HTTP/2 spec does not
 	// allow this to be smaller than 65535 or larger than 2^32-1.
@@ -446,7 +452,7 @@
 		conn:                        c,
 		baseCtx:                     baseCtx,
 		remoteAddrStr:               c.RemoteAddr().String(),
-		bw:                          newBufferedWriter(c),
+		bw:                          newBufferedWriter(s.group, c, s.WriteByteTimeout),
 		handler:                     opts.handler(),
 		streams:                     make(map[uint32]*stream),
 		readFrameCh:                 make(chan readFrameResult),
@@ -1320,6 +1326,10 @@
 	sc.writingFrame = false
 	sc.writingFrameAsync = false
 
+	if res.err != nil {
+		sc.conn.Close()
+	}
+
 	wr := res.wr
 
 	if writeEndsStream(wr.write) {
diff --git a/http2/server_test.go b/http2/server_test.go
index 47c3c61..ab53c26 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -4674,3 +4674,35 @@
 	}
 	resp.Body.Close()
 }
+
+func TestServerWriteByteTimeout(t *testing.T) {
+	const timeout = 1 * time.Second
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		w.Write(make([]byte, 100))
+	}, func(s *Server) {
+		s.WriteByteTimeout = timeout
+	})
+	st.greet()
+
+	st.cc.(*synctestNetConn).SetReadBufferSize(1) // write one byte at a time
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      1,
+		BlockFragment: st.encodeHeader(),
+		EndStream:     true,
+		EndHeaders:    true,
+	})
+
+	// Read a few bytes, staying just under WriteByteTimeout.
+	for i := 0; i < 10; i++ {
+		st.advance(timeout - 1)
+		if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
+			t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
+		}
+	}
+
+	// Wait for WriteByteTimeout.
+	// The connection should close.
+	st.advance(1 * time.Second) // timeout after writing one byte
+	st.advance(1 * time.Second) // timeout after failing to write any more bytes
+	st.wantClosed()
+}
diff --git a/http2/transport.go b/http2/transport.go
index 61f511f..49fc792 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -25,7 +25,6 @@
 	"net/http"
 	"net/http/httptrace"
 	"net/textproto"
-	"os"
 	"sort"
 	"strconv"
 	"strings"
@@ -499,6 +498,7 @@
 }
 
 type stickyErrWriter struct {
+	group   synctestGroupInterface
 	conn    net.Conn
 	timeout time.Duration
 	err     *error
@@ -508,22 +508,9 @@
 	if *sew.err != nil {
 		return 0, *sew.err
 	}
-	for {
-		if sew.timeout != 0 {
-			sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
-		}
-		nn, err := sew.conn.Write(p[n:])
-		n += nn
-		if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
-			// Keep extending the deadline so long as we're making progress.
-			continue
-		}
-		if sew.timeout != 0 {
-			sew.conn.SetWriteDeadline(time.Time{})
-		}
-		*sew.err = err
-		return n, err
-	}
+	n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
+	*sew.err = err
+	return n, err
 }
 
 // noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -792,10 +779,12 @@
 		pings:                 make(map[[8]byte]chan struct{}),
 		reqHeaderMu:           make(chan struct{}, 1),
 	}
+	var group synctestGroupInterface
 	if t.transportTestHooks != nil {
 		t.markNewGoroutine()
 		t.transportTestHooks.newclientconn(cc)
 		c = cc.tconn
+		group = t.group
 	}
 	if VerboseLogs {
 		t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -807,6 +796,7 @@
 	// TODO: adjust this writer size to account for frame size +
 	// MTU + crypto/tls record padding.
 	cc.bw = bufio.NewWriter(stickyErrWriter{
+		group:   group,
 		conn:    c,
 		timeout: t.WriteByteTimeout,
 		err:     &cc.werr,