http2: make Transport respect http1 Transport settings

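The http2 Transport now respects the http1 Transport's
DisableCompression, DisableKeepAlives, and ResponseHeaderTimeout
settings when the http2 and http1 Transports are wired together, as
ConfigureTransport does and as net/http will in the upcoming Go 1.6.

Along the way: canceling a request that has a body now returns
immediately (the body-writing goroutine sends the RST_STREAM), closed
connections are no longer offered for new requests, and a zero
content-length is only sent for POST, PUT, and PATCH, matching net/http.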

Updates golang/go#14008

Change-Id: I2f477f6fe5dbef9d0e5439dfc7f3ec2c0da7f296
Reviewed-on: https://go-review.googlesource.com/18721
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/http2/configure_transport.go b/http2/configure_transport.go
index 0edcc34..daa17f5 100644
--- a/http2/configure_transport.go
+++ b/http2/configure_transport.go
@@ -14,7 +14,11 @@
 
 func configureTransport(t1 *http.Transport) (*Transport, error) {
 	connPool := new(clientConnPool)
-	t2 := &Transport{ConnPool: noDialClientConnPool{connPool}}
+	t2 := &Transport{
+		ConnPool: noDialClientConnPool{connPool},
+		t1:       t1, // source of DisableCompression, DisableKeepAlives, ResponseHeaderTimeout
+	}
+	connPool.t = t2 // new(clientConnPool) has no Transport; give it a back-reference to t2
 	if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
 		return nil, err
 	}
diff --git a/http2/http2.go b/http2/http2.go
index f4a599c..f75e0ce 100644
--- a/http2/http2.go
+++ b/http2/http2.go
@@ -285,3 +285,14 @@
 	}
 	return true
 }
+
+type httpError struct { // has Error, Timeout and Temporary, so it satisfies net.Error
+	msg     string // returned by Error
+	timeout bool   // returned by Timeout
+}
+
+func (e *httpError) Error() string   { return e.msg }
+func (e *httpError) Timeout() bool   { return e.timeout }
+func (e *httpError) Temporary() bool { return true }
+
+var errTimeout error = &httpError{msg: "http2: timeout awaiting response headers", timeout: true} // returned by RoundTrip when ResponseHeaderTimeout elapses
diff --git a/http2/transport.go b/http2/transport.go
index 6d12725..ed159e1 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -22,6 +22,7 @@
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"golang.org/x/net/http2/hpack"
 )
@@ -84,6 +85,11 @@
 	// to mean no limit.
 	MaxHeaderListSize uint32
 
+	// t1, if non-nil, is the standard library Transport using
+	// this transport. Its settings are used (but not its
+	// RoundTrip method, etc).
+	t1 *http.Transport
+
 	connPoolOnce  sync.Once
 	connPoolOrDef ClientConnPool // non-nil version of ConnPool
 }
@@ -99,12 +105,7 @@
 }
 
 func (t *Transport) disableCompression() bool {
-	if t.DisableCompression {
-		return true
-	}
-	// TODO: also disable if this transport is somehow linked to an http1 Transport
-	// and it's configured there?
-	return false
+	return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
 }
 
 var errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
@@ -160,7 +161,7 @@
 	henc                 *hpack.Encoder
 	freeBuf              [][]byte
 
-	wmu  sync.Mutex // held while writing; acquire AFTER wmu if holding both
+	wmu  sync.Mutex // held while writing; acquire AFTER mu if holding both
 	werr error      // first write error that has occurred
 }
 
@@ -178,7 +179,7 @@
 	inflow      flow  // guarded by cc.mu
 	bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
 	readErr     error // sticky read error; owned by transportResponseBody.Read
-	stopReqBody bool  // stop writing req body; guarded by cc.mu
+	stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
 
 	peerReset chan struct{} // closed on peer reset
 	resetErr  error         // populated before peerReset is closed
@@ -221,10 +222,13 @@
 	}
 }
 
-func (cs *clientStream) abortRequestBodyWrite() {
+func (cs *clientStream) abortRequestBodyWrite(err error) {
+	if err == nil {
+		panic("nil error")
+	}
 	cc := cs.cc
 	cc.mu.Lock()
-	cs.stopReqBody = true
+	cs.stopReqBody = err
 	cc.cond.Broadcast()
 	cc.mu.Unlock()
 }
@@ -364,6 +368,12 @@
 	return cn, nil
 }
 
+// disableKeepAlives reports whether connections should be closed as
+// soon as possible.
+func (t *Transport) disableKeepAlives() bool {
+	return t.t1 != nil && t.t1.DisableKeepAlives
+}
+
 func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
 	if VerboseLogs {
 		t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
@@ -463,7 +473,7 @@
 }
 
 func (cc *ClientConn) canTakeNewRequestLocked() bool {
-	return cc.goAway == nil &&
+	return cc.goAway == nil && !cc.closed &&
 		int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
 		cc.nextStreamID < 2147483647
 }
@@ -544,6 +554,17 @@
 	return "", nil
 }
 
+func (cc *ClientConn) responseHeaderTimeout() time.Duration {
+	if cc.t.t1 != nil {
+		return cc.t.t1.ResponseHeaderTimeout
+	}
+	// No way to do this (yet?) with just an http2.Transport. Probably
+	// no need. Request.Cancel is the new way. We only need to support
+	// this for compatibility with the old http.Transport fields when
+	// we're doing transparent http2.
+	return 0
+}
+
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	trailers, err := commaSeparatedTrailers(req)
 	if err != nil {
@@ -623,17 +644,25 @@
 		return nil, werr
 	}
 
+	var respHeaderTimer <-chan time.Time
 	var bodyCopyErrc chan error // result of body copy
 	if hasBody {
 		bodyCopyErrc = make(chan error, 1)
 		go func() {
 			bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
 		}()
+	} else {
+		if d := cc.responseHeaderTimeout(); d != 0 {
+			timer := time.NewTimer(d)
+			defer timer.Stop()
+			respHeaderTimer = timer.C
+		}
 	}
 
 	readLoopResCh := cs.resc
 	requestCanceledCh := requestCancel(req)
-	requestCanceled := false
+	bodyWritten := false // true once the request body has been fully written
+
 	for {
 		select {
 		case re := <-readLoopResCh:
@@ -648,7 +677,7 @@
 				// doesn't, they'll RST_STREAM us soon enough.  This is a
 				// heuristic to avoid adding knobs to Transport.  Hopefully
 				// we can keep it.
-				cs.abortRequestBodyWrite()
+				cs.abortRequestBodyWrite(errStopReqBodyWrite)
 			}
 			if re.err != nil {
 				cc.forgetStreamID(cs.ID)
@@ -657,37 +686,37 @@
 			res.Request = req
 			res.TLS = cc.tlsState
 			return res, nil
+		case <-respHeaderTimer:
+			cc.forgetStreamID(cs.ID)
+			if !hasBody || bodyWritten {
+				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+			} else {
+				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) // NOTE(review): looks unreachable; the timer is only armed when !hasBody or after bodyWritten
+			}
+			return nil, errTimeout
 		case <-requestCanceledCh:
 			cc.forgetStreamID(cs.ID)
-			cs.abortRequestBodyWrite()
-			if !hasBody {
+			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
-				return nil, errRequestCanceled
+			} else {
+				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) // NOTE(review): if the body writer already finished, nothing may send RST_STREAM; confirm
 			}
-			// If we have a body, wait for the body write to be
-			// finished before sending the RST_STREAM frame.
-			requestCanceled = true
-			requestCanceledCh = nil // to prevent spins
-			readLoopResCh = nil     // ignore responses at this point
+			return nil, errRequestCanceled
 		case <-cs.peerReset:
-			if requestCanceled {
-				// They hung up on us first. No need to write a RST_STREAM.
-				// But prioritize the request canceled error value, since
-				// it's likely related. (same spirit as http1 code)
-				return nil, errRequestCanceled
-			}
 			// processResetStream already removed the
 			// stream from the streams map; no need for
 			// forgetStreamID.
 			return nil, cs.resetErr
 		case err := <-bodyCopyErrc:
-			if requestCanceled {
-				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
-				return nil, errRequestCanceled
-			}
 			if err != nil {
 				return nil, err
 			}
+			bodyWritten = true
+			if d := cc.responseHeaderTimeout(); d != 0 {
+				timer := time.NewTimer(d)
+				defer timer.Stop()
+				respHeaderTimer = timer.C
+			}
 		}
 	}
 }
@@ -723,9 +752,14 @@
 	return cc.werr
 }
 
-// errAbortReqBodyWrite is an internal error value.
-// It doesn't escape to callers.
-var errAbortReqBodyWrite = errors.New("http2: aborting request body write")
+// internal error values; they don't escape to callers
+var (
+	// abort request body write; don't send cancel
+	errStopReqBodyWrite = errors.New("http2: aborting request body write")
+
+	// abort request body write and send a RST_STREAM with ErrCodeCancel
+	errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
+)
 
 func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
 	cc := cs.cc
@@ -761,7 +795,13 @@
 		for len(remain) > 0 && err == nil {
 			var allowed int32
 			allowed, err = cs.awaitFlowControl(len(remain))
-			if err != nil {
+			switch {
+			case err == errStopReqBodyWrite:
+				return err
+			case err == errStopReqBodyWriteAndCancel:
+				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+				return err
+			case err != nil:
 				return err
 			}
 			cc.wmu.Lock()
@@ -821,8 +861,8 @@
 		if cc.closed {
 			return 0, errClientConnClosed
 		}
-		if cs.stopReqBody {
-			return 0, errAbortReqBodyWrite
+		if cs.stopReqBody != nil {
+			return 0, cs.stopReqBody
 		}
 		if err := cs.checkReset(); err != nil {
 			return 0, err
@@ -898,7 +938,7 @@
 			cc.writeHeader(lowKey, v)
 		}
 	}
-	if contentLength >= 0 {
+	if shouldSendReqContentLength(req.Method, contentLength) {
 		cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
 	}
 	if addGzipHeader {
@@ -910,6 +950,28 @@
 	return cc.hbuf.Bytes()
 }
 
+// shouldSendReqContentLength reports whether the http2.Transport should send
+// a "content-length" request header. This logic is basically a copy of the net/http
+// transferWriter.shouldSendContentLength.
+// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
+// -1 means unknown.
+func shouldSendReqContentLength(method string, contentLength int64) bool {
+	if contentLength > 0 {
+		return true
+	}
+	if contentLength < 0 {
+		return false
+	}
+	// For zero bodies, whether we send a content-length depends on the method.
+	// It also kinda doesn't matter for http2 either way, with END_STREAM.
+	switch method {
+	case "POST", "PUT", "PATCH":
+		return true
+	default:
+		return false
+	}
+}
+
 // requires cc.mu be held.
 func (cc *ClientConn) encodeTrailers(req *http.Request) []byte {
 	cc.hbuf.Reset()
@@ -1032,6 +1094,8 @@
 
 func (rl *clientConnReadLoop) run() error {
 	cc := rl.cc
+	closeWhenIdle := cc.t.disableKeepAlives()
+	gotReply := false // ever saw a reply
 	for {
 		f, err := cc.fr.ReadFrame()
 		if err != nil {
@@ -1046,18 +1110,25 @@
 		if VerboseLogs {
 			cc.vlogf("http2: Transport received %s", summarizeFrame(f))
 		}
+		maybeClose := false // whether frame might transition us to idle
 
 		switch f := f.(type) {
 		case *HeadersFrame:
 			err = rl.processHeaders(f)
+			maybeClose = true
+			gotReply = true
 		case *ContinuationFrame:
 			err = rl.processContinuation(f)
+			maybeClose = true
 		case *DataFrame:
 			err = rl.processData(f)
+			maybeClose = true
 		case *GoAwayFrame:
 			err = rl.processGoAway(f)
+			maybeClose = true
 		case *RSTStreamFrame:
 			err = rl.processResetStream(f)
+			maybeClose = true
 		case *SettingsFrame:
 			err = rl.processSettings(f)
 		case *PushPromiseFrame:
@@ -1072,6 +1143,9 @@
 		if err != nil {
 			return err
 		}
+		if closeWhenIdle && gotReply && maybeClose && len(rl.activeRes) == 0 {
+			cc.closeIfIdle()
+		}
 	}
 }
 
diff --git a/http2/transport_test.go b/http2/transport_test.go
index dab483a..77766be 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -99,7 +99,6 @@
 	} else if string(slurp) != body {
 		t.Errorf("Body = %q; want %q", slurp, body)
 	}
-
 }
 
 func TestTransportReusesConns(t *testing.T) {
@@ -1318,3 +1317,231 @@
 	c := &http.Client{Transport: tr}
 	c.Get(st.ts.URL)
 }
+
+// Test that the http1 Transport.DisableKeepAlives option is respected
+// and connections are closed as soon as idle.
+// See golang.org/issue/14008
+func TestTransportDisableKeepAlives(t *testing.T) {
+	st := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {
+			io.WriteString(w, "hi")
+		},
+		optOnlyServer,
+	)
+	defer st.Close()
+
+	connClosed := make(chan struct{}) // closed when the dialed conn is closed
+	tr := &Transport{
+		t1: &http.Transport{
+			DisableKeepAlives: true,
+		},
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			tc, err := tls.Dial(network, addr, cfg)
+			if err != nil {
+				return nil, err
+			}
+			return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
+		},
+	}
+	c := &http.Client{Transport: tr}
+	res, err := c.Get(st.ts.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+	if _, err := ioutil.ReadAll(res.Body); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case <-connClosed:
+	case <-time.After(1 * time.Second):
+		t.Errorf("timeout waiting for idle connection to close")
+	}
+}
+
+// Test concurrent requests with Transport.DisableKeepAlives. We can share connections,
+// but when things are totally idle, it still needs to close.
+func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
+	const D = 25 * time.Millisecond
+	st := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {
+			time.Sleep(D)
+			io.WriteString(w, "hi")
+		},
+		optOnlyServer,
+	)
+	defer st.Close()
+
+	var dials int32
+	var conns sync.WaitGroup
+	tr := &Transport{
+		t1: &http.Transport{
+			DisableKeepAlives: true,
+		},
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			tc, err := tls.Dial(network, addr, cfg)
+			if err != nil {
+				return nil, err
+			}
+			atomic.AddInt32(&dials, 1)
+			conns.Add(1)
+			return &noteCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
+		},
+	}
+	c := &http.Client{Transport: tr}
+	var reqs sync.WaitGroup
+	const N = 20
+	for i := 0; i < N; i++ {
+		reqs.Add(1)
+		if i == N-1 {
+			// For the final request, try to make all the
+			// others close. This isn't verified in the
+			// count, other than the Log statement, since
+			// it's so timing dependent. This test is
+			// really to make sure we don't interrupt a
+			// valid request.
+			time.Sleep(D * 2)
+		}
+		go func() {
+			defer reqs.Done()
+			res, err := c.Get(st.ts.URL)
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			defer res.Body.Close()
+			if _, err := ioutil.ReadAll(res.Body); err != nil {
+				t.Error(err)
+			}
+		}()
+	}
+	reqs.Wait()
+	conns.Wait()
+	t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
+}
+
+// noteCloseConn wraps a net.Conn and runs closefn once, on the first
+// call to Close, before closing the underlying conn.
+type noteCloseConn struct {
+	net.Conn
+	onceClose sync.Once
+	closefn   func()
+}
+
+func (c *noteCloseConn) Close() error {
+	c.onceClose.Do(c.closefn)
+	return c.Conn.Close()
+}
+
+// isTimeout reports whether err, or the error wrapped by a *url.Error,
+// is a net.Error whose Timeout method reports true.
+func isTimeout(err error) bool {
+	switch err := err.(type) {
+	case nil:
+		return false
+	case *url.Error:
+		return isTimeout(err.Err)
+	case net.Error:
+		return err.Timeout()
+	}
+	return false
+}
+
+// Test that the http1 Transport.ResponseHeaderTimeout option is respected
+// and that a RST_STREAM with ErrCodeCancel is sent when it fires.
+func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
+	testTransportResponseHeaderTimeout(t, false)
+}
+func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
+	testTransportResponseHeaderTimeout(t, true)
+}
+
+// testTransportResponseHeaderTimeout runs a request (with a 4MB body if
+// body is true) against a server that replenishes flow control but never
+// sends response headers, and expects the client to get a timeout error.
+func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
+	ct := newClientTester(t)
+	ct.tr.t1 = &http.Transport{
+		ResponseHeaderTimeout: 5 * time.Millisecond,
+	}
+	ct.client = func() error {
+		c := &http.Client{Transport: ct.tr}
+		var err error
+		var n int64
+		const bodySize = 4 << 20
+		if body {
+			_, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
+		} else {
+			_, err = c.Get("https://dummy.tld/")
+		}
+		if !isTimeout(err) {
+			t.Errorf("client expected timeout error; got %#v", err)
+		}
+		if body && n != bodySize {
+			t.Errorf("only read %d bytes of body; want %d", n, bodySize)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				t.Logf("ReadFrame: %v", err)
+				return nil
+			}
+			switch f := f.(type) {
+			case *DataFrame:
+				dataLen := len(f.Data())
+				if dataLen > 0 {
+					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
+						return err
+					}
+					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
+						return err
+					}
+				}
+			case *RSTStreamFrame:
+				if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
+					return nil
+				}
+			}
+		}
+	}
+	ct.run()
+}
+
+// Test that the http1 Transport.DisableCompression option is respected:
+// the server must see only a User-Agent header (no Accept-Encoding).
+func TestTransportDisableCompression(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		want := http.Header{
+			"User-Agent": []string{"Go-http-client/2.0"},
+		}
+		if !reflect.DeepEqual(r.Header, want) {
+			t.Errorf("request headers = %v; want %v", r.Header, want)
+		}
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+		t1: &http.Transport{
+			DisableCompression: true,
+		},
+	}
+	defer tr.CloseIdleConnections()
+
+	req, err := http.NewRequest("GET", st.ts.URL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	res, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+}