http2: add DialTLS to Transport

This commit allows a client of Transport to supply their own Dial
function that assumes all TLS checks have been performed and the
returned net.Conn is an h2 ready client connection.

Change-Id: If35b5c47c3bd6912a990d6cd89feefa3303bb42b
Reviewed-on: https://go-review.googlesource.com/16289
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 2122c9b..46507aa 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -44,8 +44,18 @@
 // A Transport internally caches connections to servers. It is safe
 // for concurrent use by multiple goroutines.
 type Transport struct {
-	// TODO: remove this and make more general with a TLS dial hook, like http
-	InsecureTLSDial bool
+	// DialTLS specifies an optional dial function for creating
+	// TLS connections for requests.
+	//
+	// If DialTLS is nil, tls.Dial is used.
+	//
+	// If the returned net.Conn has a ConnectionState method like tls.Conn,
+	// it will be used to set http.Response.TLS.
+	DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error)
+
+	// TLSClientConfig specifies the TLS configuration to use with
+	// tls.Client. If nil, the default configuration is used.
+	TLSClientConfig *tls.Config
 
 	// TODO: switch to RWMutex
 	// TODO: add support for sharing conns based on cert names
@@ -58,7 +68,7 @@
 // HTTP/2 server.
 type clientConn struct {
 	t        *Transport
-	tconn    *tls.Conn
+	tconn    net.Conn
 	tlsState *tls.ConnectionState
 	connKey  []string // key(s) this connection is cached in, in t.conns
 
@@ -213,14 +223,13 @@
 // The addr maybe be either "host" or "host:port".
 func (t *Transport) AddIdleConn(addr string, c *tls.Conn) error {
 	var key string
-	host, _, err := net.SplitHostPort(addr)
+	_, _, err := net.SplitHostPort(addr)
 	if err == nil {
 		key = addr
 	} else {
-		host = addr
 		key = addr + ":443"
 	}
-	cc, err := t.newClientConn(host, key, c)
+	cc, err := t.newClientConn(key, c)
 	if err != nil {
 		return err
 	}
@@ -263,34 +272,54 @@
 }
 
 func (t *Transport) dialClientConn(host, port, key string) (*clientConn, error) {
-	cfg := &tls.Config{
-		ServerName:         host,
-		NextProtos:         []string{NextProtoTLS},
-		InsecureSkipVerify: t.InsecureTLSDial,
-	}
-	tconn, err := tls.Dial("tcp", net.JoinHostPort(host, port), cfg)
+	tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host))
 	if err != nil {
 		return nil, err
 	}
-	return t.newClientConn(host, key, tconn)
+	return t.newClientConn(key, tconn)
 }
 
-func (t *Transport) newClientConn(host, key string, tconn *tls.Conn) (*clientConn, error) {
-	if err := tconn.Handshake(); err != nil {
+func (t *Transport) newTLSConfig(host string) *tls.Config {
+	cfg := new(tls.Config)
+	if t.TLSClientConfig != nil {
+		*cfg = *t.TLSClientConfig
+	}
+	cfg.NextProtos = []string{NextProtoTLS} // TODO: don't override if already in list
+	cfg.ServerName = host
+	return cfg
+}
+
+func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
+	if t.DialTLS != nil {
+		return t.DialTLS
+	}
+	return t.dialTLSDefault
+}
+
+func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
+	cn, err := tls.Dial(network, addr, cfg)
+	if err != nil {
 		return nil, err
 	}
-	if !t.InsecureTLSDial {
-		if err := tconn.VerifyHostname(host); err != nil {
+	if err := cn.Handshake(); err != nil {
+		return nil, err
+	}
+	if !cfg.InsecureSkipVerify {
+		if err := cn.VerifyHostname(cfg.ServerName); err != nil {
 			return nil, err
 		}
 	}
-	state := tconn.ConnectionState()
+	state := cn.ConnectionState()
 	if p := state.NegotiatedProtocol; p != NextProtoTLS {
 		return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
 	}
 	if !state.NegotiatedProtocolIsMutual {
 		return nil, errors.New("http2: could not negotiate protocol mutually")
 	}
+	return cn, nil
+}
+
+func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, error) {
 	if _, err := tconn.Write(clientPreface); err != nil {
 		return nil, err
 	}
@@ -299,7 +328,6 @@
 		t:                    t,
 		tconn:                tconn,
 		connKey:              []string{key}, // TODO: cert's validated hostnames too
-		tlsState:             &state,
 		readerDone:           make(chan struct{}),
 		nextStreamID:         1,
 		maxFrameSize:         16 << 10, // spec default
@@ -316,6 +344,15 @@
 	cc.br = bufio.NewReader(tconn)
 	cc.fr = NewFramer(cc.bw, cc.br)
 	cc.henc = hpack.NewEncoder(&cc.hbuf)
+
+	type connectionStater interface {
+		ConnectionState() tls.ConnectionState
+	}
+	if cs, ok := tconn.(connectionStater); ok {
+		state := cs.ConnectionState()
+		cc.tlsState = &state
+	}
+
 	cc.fr.WriteSettings(
 		Setting{ID: SettingEnablePush, Val: 0},
 		Setting{ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 7bed140..8bbf60a 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5,13 +5,16 @@
 package http2
 
 import (
+	"crypto/tls"
 	"flag"
 	"io"
 	"io/ioutil"
+	"net"
 	"net/http"
 	"os"
 	"reflect"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 )
@@ -22,14 +25,14 @@
 	insecure      = flag.Bool("insecure", false, "insecure TLS dials")
 )
 
+var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
+
 func TestTransportExternal(t *testing.T) {
 	if !*extNet {
 		t.Skip("skipping external network test")
 	}
 	req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
-	rt := &Transport{
-		InsecureTLSDial: *insecure,
-	}
+	rt := &Transport{TLSClientConfig: tlsConfigInsecure}
 	res, err := rt.RoundTrip(req)
 	if err != nil {
 		t.Fatalf("%v", err)
@@ -44,7 +47,7 @@
 	}, optOnlyServer)
 	defer st.Close()
 
-	tr := &Transport{InsecureTLSDial: true}
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
 	defer tr.CloseIdleConnections()
 
 	req, err := http.NewRequest("GET", st.ts.URL, nil)
@@ -91,7 +94,7 @@
 		io.WriteString(w, r.RemoteAddr)
 	}, optOnlyServer)
 	defer st.Close()
-	tr := &Transport{InsecureTLSDial: true}
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
 	defer tr.CloseIdleConnections()
 	get := func() string {
 		req, err := http.NewRequest("GET", st.ts.URL, nil)
@@ -136,9 +139,7 @@
 	requestMade := make(chan struct{})
 	go func() {
 		defer close(done)
-		tr := &Transport{
-			InsecureTLSDial: true,
-		}
+		tr := &Transport{TLSClientConfig: tlsConfigInsecure}
 		req, err := http.NewRequest("GET", st.ts.URL, nil)
 		if err != nil {
 			t.Fatal(err)
@@ -182,9 +183,7 @@
 	)
 	defer st.Close()
 
-	tr := &Transport{
-		InsecureTLSDial: true,
-	}
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
 	defer tr.CloseIdleConnections()
 	const body = "Some message"
 	req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(body))
@@ -204,3 +203,45 @@
 		t.Errorf("Read body = %q; want %q", got, body)
 	}
 }
+
+func TestTransportDialTLS(t *testing.T) {
+	var mu sync.Mutex // guards following
+	var gotReq, didDial bool
+
+	ts := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {
+			mu.Lock()
+			gotReq = true
+			mu.Unlock()
+		},
+		optOnlyServer,
+	)
+	defer ts.Close()
+	tr := &Transport{
+		DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
+			mu.Lock()
+			didDial = true
+			mu.Unlock()
+			cfg.InsecureSkipVerify = true
+			c, err := tls.Dial(netw, addr, cfg)
+			if err != nil {
+				return nil, err
+			}
+			return c, c.Handshake()
+		},
+	}
+	defer tr.CloseIdleConnections()
+	client := &http.Client{Transport: tr}
+	res, err := client.Get(ts.ts.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	res.Body.Close()
+	mu.Lock()
+	if !gotReq {
+		t.Error("didn't get request")
+	}
+	if !didDial {
+		t.Error("didn't use dial hook")
+	}
+}