http2: make Transport respect http1 Transport settings

The http2 Transport now respects the http1 Transport's
DisableCompression, DisableKeepAlives, and ResponseHeaderTimeout, if
the http2 and http1 Transports are wired up together, as they are in
the upcoming Go 1.6.

Updates golang/go#14008

Change-Id: I2f477f6fe5dbef9d0e5439dfc7f3ec2c0da7f296
Reviewed-on: https://go-review.googlesource.com/18721
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/http2/transport_test.go b/http2/transport_test.go
index dab483a..77766be 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -99,7 +99,6 @@
 	} else if string(slurp) != body {
 		t.Errorf("Body = %q; want %q", slurp, body)
 	}
-
 }
 
 func TestTransportReusesConns(t *testing.T) {
@@ -1318,3 +1317,225 @@
 	c := &http.Client{Transport: tr}
 	c.Get(st.ts.URL)
 }
+
+// Test that the http1 Transport.DisableKeepAlives option is respected
+// and connections are closed as soon as idle.
+// See golang.org/issue/14008
+func TestTransportDisableKeepAlives(t *testing.T) {
+	st := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {
+			io.WriteString(w, "hi")
+		},
+		optOnlyServer,
+	)
+	defer st.Close()
+
+	connClosed := make(chan struct{}) // closed by the first Close of the dialed conn
+	tr := &Transport{
+		t1: &http.Transport{
+			DisableKeepAlives: true,
+		},
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			tc, err := tls.Dial(network, addr, cfg)
+			if err != nil {
+				return nil, err
+			}
+			return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
+		},
+	}
+	c := &http.Client{Transport: tr}
+	res, err := c.Get(st.ts.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+	if _, err := ioutil.ReadAll(res.Body); err != nil {
+		t.Fatal(err)
+	}
+
+	// With keep-alives disabled, the now-idle connection must be closed promptly.
+	select {
+	case <-connClosed:
+	case <-time.After(1 * time.Second):
+		t.Errorf("timeout")
+	}
+}
+
+// Test concurrent requests with Transport.DisableKeepAlives. Connections may be
+// shared while requests overlap, but once everything is idle they must be closed.
+func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
+	const D = 25 * time.Millisecond
+	st := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {
+			time.Sleep(D)
+			io.WriteString(w, "hi")
+		},
+		optOnlyServer,
+	)
+	defer st.Close()
+
+	var (
+		numDials  int32          // TLS dials performed by the transport
+		openConns sync.WaitGroup // one count per dialed conn, released on its Close
+	)
+	tr := &Transport{
+		t1: &http.Transport{
+			DisableKeepAlives: true,
+		},
+		TLSClientConfig: tlsConfigInsecure,
+		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+			tc, err := tls.Dial(network, addr, cfg)
+			if err != nil {
+				return nil, err
+			}
+			atomic.AddInt32(&numDials, 1)
+			openConns.Add(1)
+			return &noteCloseConn{Conn: tc, closefn: openConns.Done}, nil
+		},
+	}
+	c := &http.Client{Transport: tr}
+	var inflight sync.WaitGroup
+	const N = 20
+	for i := 0; i < N; i++ {
+		if i == N-1 {
+			// Before the final request, pause long enough for the earlier
+			// connections to go idle and close. That count is only logged,
+			// not asserted, because it is timing dependent; the real check is
+			// that closing idle conns never interrupts a valid request.
+			time.Sleep(D * 2)
+		}
+		inflight.Add(1)
+		go func() {
+			defer inflight.Done()
+			res, err := c.Get(st.ts.URL)
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			if _, err := ioutil.ReadAll(res.Body); err != nil {
+				t.Error(err)
+				return
+			}
+			res.Body.Close()
+		}()
+	}
+	inflight.Wait()
+	openConns.Wait()
+	t.Logf("did %d dials, %d requests", atomic.LoadInt32(&numDials), N)
+}
+
+type noteCloseConn struct {
+	net.Conn
+	closeOnce sync.Once // ensures closefn runs at most once
+	closefn   func()    // hook invoked by the first Close, before Conn.Close
+}
+
+func (nc *noteCloseConn) Close() error {
+	nc.closeOnce.Do(nc.closefn)
+	return nc.Conn.Close()
+}
+
+// isTimeout reports whether err, after peeling off any *url.Error wrappers, is a timeout net.Error.
+func isTimeout(err error) bool {
+	ue, isURLErr := err.(*url.Error)
+	for isURLErr {
+		err = ue.Err
+		ue, isURLErr = err.(*url.Error)
+	}
+
+	ne, isNetErr := err.(net.Error)
+	return isNetErr && ne.Timeout()
+}
+
+// Test that the http1 Transport.ResponseHeaderTimeout option is respected, without and with a request body.
+func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
+	testTransportResponseHeaderTimeout(t, false /* body */)
+}
+func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
+	testTransportResponseHeaderTimeout(t, true /* body */)
+}
+
+// testTransportResponseHeaderTimeout sends a GET (or, when body is true, a 4 MiB POST) to a
+// server that never writes response headers, and expects a timeout error from the client.
+func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
+	ct := newClientTester(t)
+	ct.tr.t1 = &http.Transport{ResponseHeaderTimeout: 5 * time.Millisecond}
+	ct.client = func() error {
+		c := &http.Client{Transport: ct.tr}
+		var err error
+		var n int64
+		const bodySize = 4 << 20
+		if body {
+			_, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
+		} else {
+			_, err = c.Get("https://dummy.tld/")
+		}
+		if !isTimeout(err) {
+			t.Errorf("client expected timeout error; got %#v", err)
+		}
+		if body && n != bodySize {
+			t.Errorf("only read %d bytes of body; want %d", n, bodySize)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				t.Logf("ReadFrame: %v", err)
+				return nil
+			}
+			switch f := f.(type) {
+			case *DataFrame:
+				// Return connection- (stream 0) and stream-level flow-control credit so the whole body can be sent.
+				dataLen := len(f.Data())
+				if dataLen > 0 {
+					if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
+						return err
+					}
+					if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
+						return err
+					}
+				}
+			case *RSTStreamFrame:
+				if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
+					return nil
+				}
+			}
+		}
+	}
+	ct.run()
+}
+
+// Test that with the http1 Transport.DisableCompression option set, the only request header sent is User-Agent.
+func TestTransportDisableCompression(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		want := http.Header{
+			"User-Agent": []string{"Go-http-client/2.0"},
+		}
+		if !reflect.DeepEqual(r.Header, want) {
+			t.Errorf("request headers = %v; want %v", r.Header, want)
+		}
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{
+		TLSClientConfig: tlsConfigInsecure,
+		t1: &http.Transport{
+			DisableCompression: true,
+		},
+	}
+	defer tr.CloseIdleConnections()
+
+	req, err := http.NewRequest("GET", st.ts.URL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	res, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+}