http2: add ConfigureTransport, like ConfigureServer

This makes a net/http.Transport (HTTP/1) HTTP/2-capable.

Move this code out of net/http internals, so others can use it.
And add tests.

Some reflection and +build tag work is required so it is a no-op and
returns an error for Go 1.5 and below.

Change-Id: I539f233509602009e6c1179b810ff509a1f83ae3
Reviewed-on: https://go-review.googlesource.com/16734
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/configure_transport.go b/http2/configure_transport.go
new file mode 100644
index 0000000..fb8979a
--- /dev/null
+++ b/http2/configure_transport.go
@@ -0,0 +1,78 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.6
+
+package http2
+
+import (
+	"crypto/tls"
+	"fmt"
+	"net/http"
+)
+
+// configureTransport HTTP/2-enables t1: it registers an "https" RoundTripper,
+// advertises "h2" via ALPN, and installs a TLSNextProto["h2"] upgrade hook.
+func configureTransport(t1 *http.Transport) error {
+	pool := new(clientConnPool)
+	h2t := &Transport{ConnPool: noDialClientConnPool{pool}}
+	if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{h2t}); err != nil {
+		return err // e.g. "https" was already registered on t1
+	}
+	if t1.TLSClientConfig == nil {
+		t1.TLSClientConfig = new(tls.Config)
+	}
+	// Prepend "h2" so it is offered first in the client's ALPN list.
+	if tc := t1.TLSClientConfig; !strSliceContains(tc.NextProtos, "h2") {
+		tc.NextProtos = append([]string{"h2"}, tc.NextProtos...)
+	}
+	upgrade := func(authority string, conn *tls.Conn) http.RoundTripper {
+		cc, err := h2t.NewClientConn(conn)
+		if err != nil {
+			conn.Close()
+			return erringRoundTripper{err}
+		}
+		pool.addConn(authorityAddr(authority), cc)
+		return h2t
+	}
+	if len(t1.TLSNextProto) == 0 {
+		t1.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
+	}
+	t1.TLSNextProto["h2"] = upgrade
+	return nil
+}
+
+// registerHTTPSProtocol calls t.RegisterProtocol("https", rt), converting a
+// panic from it (such as "https" already being registered) into an error.
+func registerHTTPSProtocol(t *http.Transport, rt http.RoundTripper) (err error) {
+	defer func() {
+		if p := recover(); p != nil {
+			err = fmt.Errorf("%v", p)
+		}
+	}()
+	t.RegisterProtocol("https", rt)
+	return nil
+}
+
+// noDialClientConnPool is a ClientConnPool that never dials. The HTTP/1.1
+// Transport dials and performs the TLS handshake; the resulting conn is then
+// added to the embedded clientConnPool, from which requests are served.
+type noDialClientConnPool struct{ *clientConnPool }
+
+// GetClientConn asks the embedded pool for a conn to addr with dialing disabled.
+func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
+	return p.getClientConn(req, addr, false /* doDial */)
+}
+
+// noDialH2RoundTripper serves requests only over already-cached HTTP/2 conns;
+// otherwise it returns http.ErrSkipAltProtocol so the HTTP/1 Transport takes over.
+type noDialH2RoundTripper struct{ t *Transport }
+
+func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	res, err := rt.t.RoundTrip(req)
+	if err != ErrNoCachedConn {
+		return res, err
+	}
+	return nil, http.ErrSkipAltProtocol
+}
diff --git a/http2/go15.go b/http2/go15.go
new file mode 100644
index 0000000..dbf6033
--- /dev/null
+++ b/http2/go15.go
@@ -0,0 +1,13 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.6
+
+package http2
+
+import "net/http"
+
+// configureTransport is the Go 1.5-and-earlier stub selected by the !go1.6
+// build tag; it leaves t1 untouched and always returns errTransportVersion.
+func configureTransport(t1 *http.Transport) error { return errTransportVersion }
diff --git a/http2/transport.go b/http2/transport.go
index fa8a2a4..9d53d44 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -65,6 +65,15 @@
 	connPoolOrDef ClientConnPool // non-nil version of ConnPool
 }
 
+// errTransportVersion is what ConfigureTransport returns on Go 1.5 and earlier.
+var errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
+
+// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. It
+// returns an error on Go 1.5 and earlier, or if t1 has already been HTTP/2-enabled.
+func ConfigureTransport(t1 *http.Transport) error {
+	return configureTransport(t1) // configure_transport.go (go1.6+) or go15.go (!go1.6)
+}
+
 func (t *Transport) connPool() ClientConnPool {
 	t.connPoolOnce.Do(t.initConnPool)
 	return t.connPoolOrDef
@@ -1072,3 +1081,16 @@
 }
 
 var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil))
+
+// strSliceContains reports whether s is an element of ss (exact match).
+func strSliceContains(ss []string, s string) bool {
+	found := false
+	for i := 0; i < len(ss) && !found; i++ {
+		found = ss[i] == s
+	}
+	return found
+}
+
+// erringRoundTripper is an http.RoundTripper whose RoundTrip always fails with err.
+type erringRoundTripper struct{ err error }
+func (e erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, e.err }
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 6b23563..b7385d6 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -330,3 +330,46 @@
 		t.Error("didn't use dial hook")
 	}
 }
+
+func TestConfigureTransport(t *testing.T) {
+	t1 := &http.Transport{}
+	switch err := ConfigureTransport(t1); {
+	case err == errTransportVersion:
+		t.Skip(err)
+	case err != nil:
+		t.Fatal(err)
+	}
+	if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) {
+		// Laziness, to avoid buildtags. Format t1, not *t1: a copy would duplicate its locks.
+		t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
+	}
+	if t1.TLSClientConfig == nil {
+		t.Fatal("nil t1.TLSClientConfig") // fatal: it is dereferenced below
+	}
+	if got := t1.TLSClientConfig.NextProtos; !reflect.DeepEqual(got, []string{"h2"}) {
+		t.Errorf("TLSClientConfig.NextProtos = %q; want just 'h2'", got)
+	}
+	if err := ConfigureTransport(t1); err == nil {
+		t.Error("unexpected success on second call to ConfigureTransport")
+	}
+
+	// And does it work?
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, r.Proto)
+	}, optOnlyServer)
+	defer st.Close()
+	t1.TLSClientConfig.InsecureSkipVerify = true
+	c := &http.Client{Transport: t1}
+	res, err := c.Get(st.ts.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+	slurp, err := ioutil.ReadAll(res.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got, want := string(slurp), "HTTP/2.0"; got != want {
+		t.Errorf("body = %q; want %q", got, want)
+	}
+}