http2: add DialTLS to Transport
This commit allows a client of Transport to supply their own Dial
function that assumes all TLS checks have been performed and the
returned net.Conn is an h2 ready client connection.
Change-Id: If35b5c47c3bd6912a990d6cd89feefa3303bb42b
Reviewed-on: https://go-review.googlesource.com/16289
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 2122c9b..46507aa 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -44,8 +44,18 @@
// A Transport internally caches connections to servers. It is safe
// for concurrent use by multiple goroutines.
type Transport struct {
- // TODO: remove this and make more general with a TLS dial hook, like http
- InsecureTLSDial bool
+ // DialTLS specifies an optional dial function for creating
+ // TLS connections for requests.
+ //
+ // If DialTLS is nil, tls.Dial is used.
+ //
+ // If the returned net.Conn has a ConnectionState method like tls.Conn,
+ // it will be used to set http.Response.TLS.
+ DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error)
+
+ // TLSClientConfig specifies the TLS configuration to use with
+ // tls.Client. If nil, the default configuration is used.
+ TLSClientConfig *tls.Config
// TODO: switch to RWMutex
// TODO: add support for sharing conns based on cert names
@@ -58,7 +68,7 @@
// HTTP/2 server.
type clientConn struct {
t *Transport
- tconn *tls.Conn
+ tconn net.Conn
tlsState *tls.ConnectionState
connKey []string // key(s) this connection is cached in, in t.conns
@@ -213,14 +223,13 @@
// The addr maybe be either "host" or "host:port".
func (t *Transport) AddIdleConn(addr string, c *tls.Conn) error {
var key string
- host, _, err := net.SplitHostPort(addr)
+ _, _, err := net.SplitHostPort(addr)
if err == nil {
key = addr
} else {
- host = addr
key = addr + ":443"
}
- cc, err := t.newClientConn(host, key, c)
+ cc, err := t.newClientConn(key, c)
if err != nil {
return err
}
@@ -263,34 +272,54 @@
}
func (t *Transport) dialClientConn(host, port, key string) (*clientConn, error) {
- cfg := &tls.Config{
- ServerName: host,
- NextProtos: []string{NextProtoTLS},
- InsecureSkipVerify: t.InsecureTLSDial,
- }
- tconn, err := tls.Dial("tcp", net.JoinHostPort(host, port), cfg)
+ tconn, err := t.dialTLS()("tcp", net.JoinHostPort(host, port), t.newTLSConfig(host))
if err != nil {
return nil, err
}
- return t.newClientConn(host, key, tconn)
+ return t.newClientConn(key, tconn)
}
-func (t *Transport) newClientConn(host, key string, tconn *tls.Conn) (*clientConn, error) {
- if err := tconn.Handshake(); err != nil {
+func (t *Transport) newTLSConfig(host string) *tls.Config {
+ cfg := new(tls.Config)
+ if t.TLSClientConfig != nil {
+ *cfg = *t.TLSClientConfig
+ }
+ cfg.NextProtos = []string{NextProtoTLS} // TODO: don't override if already in list
+ cfg.ServerName = host
+ return cfg
+}
+
+func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
+ if t.DialTLS != nil {
+ return t.DialTLS
+ }
+ return t.dialTLSDefault
+}
+
+func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ cn, err := tls.Dial(network, addr, cfg)
+ if err != nil {
return nil, err
}
- if !t.InsecureTLSDial {
- if err := tconn.VerifyHostname(host); err != nil {
+ if err := cn.Handshake(); err != nil {
+ return nil, err
+ }
+ if !cfg.InsecureSkipVerify {
+ if err := cn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
}
- state := tconn.ConnectionState()
+ state := cn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("http2: could not negotiate protocol mutually")
}
+ return cn, nil
+}
+
+func (t *Transport) newClientConn(key string, tconn net.Conn) (*clientConn, error) {
if _, err := tconn.Write(clientPreface); err != nil {
return nil, err
}
@@ -299,7 +328,6 @@
t: t,
tconn: tconn,
connKey: []string{key}, // TODO: cert's validated hostnames too
- tlsState: &state,
readerDone: make(chan struct{}),
nextStreamID: 1,
maxFrameSize: 16 << 10, // spec default
@@ -316,6 +344,15 @@
cc.br = bufio.NewReader(tconn)
cc.fr = NewFramer(cc.bw, cc.br)
cc.henc = hpack.NewEncoder(&cc.hbuf)
+
+ type connectionStater interface {
+ ConnectionState() tls.ConnectionState
+ }
+ if cs, ok := tconn.(connectionStater); ok {
+ state := cs.ConnectionState()
+ cc.tlsState = &state
+ }
+
cc.fr.WriteSettings(
Setting{ID: SettingEnablePush, Val: 0},
Setting{ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 7bed140..8bbf60a 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5,13 +5,16 @@
package http2
import (
+ "crypto/tls"
"flag"
"io"
"io/ioutil"
+ "net"
"net/http"
"os"
"reflect"
"strings"
+ "sync"
"testing"
"time"
)
@@ -22,14 +25,14 @@
insecure = flag.Bool("insecure", false, "insecure TLS dials")
)
+var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
+
func TestTransportExternal(t *testing.T) {
if !*extNet {
t.Skip("skipping external network test")
}
req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
- rt := &Transport{
- InsecureTLSDial: *insecure,
- }
+ rt := &Transport{TLSClientConfig: tlsConfigInsecure}
res, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("%v", err)
@@ -44,7 +47,7 @@
}, optOnlyServer)
defer st.Close()
- tr := &Transport{InsecureTLSDial: true}
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
req, err := http.NewRequest("GET", st.ts.URL, nil)
@@ -91,7 +94,7 @@
io.WriteString(w, r.RemoteAddr)
}, optOnlyServer)
defer st.Close()
- tr := &Transport{InsecureTLSDial: true}
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
get := func() string {
req, err := http.NewRequest("GET", st.ts.URL, nil)
@@ -136,9 +139,7 @@
requestMade := make(chan struct{})
go func() {
defer close(done)
- tr := &Transport{
- InsecureTLSDial: true,
- }
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
@@ -182,9 +183,7 @@
)
defer st.Close()
- tr := &Transport{
- InsecureTLSDial: true,
- }
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
const body = "Some message"
req, err := http.NewRequest("POST", st.ts.URL, strings.NewReader(body))
@@ -204,3 +203,45 @@
t.Errorf("Read body = %q; want %q", got, body)
}
}
+
+func TestTransportDialTLS(t *testing.T) {
+ var mu sync.Mutex // guards following
+ var gotReq, didDial bool
+
+ ts := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ },
+ optOnlyServer,
+ )
+ defer ts.Close()
+ tr := &Transport{
+ DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
+ mu.Lock()
+ didDial = true
+ mu.Unlock()
+ cfg.InsecureSkipVerify = true
+ c, err := tls.Dial(netw, addr, cfg)
+ if err != nil {
+ return nil, err
+ }
+ return c, c.Handshake()
+ },
+ }
+ defer tr.CloseIdleConnections()
+ client := &http.Client{Transport: tr}
+ res, err := client.Get(ts.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if !didDial {
+ t.Error("didn't use dial hook")
+ }
+}