http2: use (*tls.Dialer).DialContext in dialTLS
This lets us propagate the request context into the TLS
handshake.
Related to CL 295370
Updates golang/go#32406
Change-Id: Ie10c301be19b57b4b3e46ac31bbe87679e1eebc7
Reviewed-on: https://go-review.googlesource.com/c/net/+/295173
Trust: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index 3a67636..652bc11 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -7,7 +7,9 @@
package http2
import (
+ "context"
"crypto/tls"
+ "errors"
"net/http"
"sync"
)
@@ -78,61 +80,69 @@
// It gets its own connection.
traceGetConn(req, addr)
const singleUse = true
- cc, err := p.t.dialClientConn(addr, singleUse)
+ cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
if err != nil {
return nil, err
}
return cc, nil
}
- p.mu.Lock()
- for _, cc := range p.conns[addr] {
- if st := cc.idleState(); st.canTakeNewRequest {
- if p.shouldTraceGetConn(st) {
- traceGetConn(req, addr)
+ for {
+ p.mu.Lock()
+ for _, cc := range p.conns[addr] {
+ if st := cc.idleState(); st.canTakeNewRequest {
+ if p.shouldTraceGetConn(st) {
+ traceGetConn(req, addr)
+ }
+ p.mu.Unlock()
+ return cc, nil
}
- p.mu.Unlock()
- return cc, nil
}
- }
- if !dialOnMiss {
+ if !dialOnMiss {
+ p.mu.Unlock()
+ return nil, ErrNoCachedConn
+ }
+ traceGetConn(req, addr)
+ call := p.getStartDialLocked(req.Context(), addr)
p.mu.Unlock()
- return nil, ErrNoCachedConn
+ <-call.done
+ if shouldRetryDial(call, req) {
+ continue
+ }
+ return call.res, call.err
}
- traceGetConn(req, addr)
- call := p.getStartDialLocked(addr)
- p.mu.Unlock()
- <-call.done
- return call.res, call.err
}
// dialCall is an in-flight Transport dial call to a host.
type dialCall struct {
- _ incomparable
- p *clientConnPool
+ _ incomparable
+ p *clientConnPool
+ // the context associated with the request
+ // that created this dialCall
+ ctx context.Context
done chan struct{} // closed when done
res *ClientConn // valid after done is closed
err error // valid after done is closed
}
// requires p.mu is held.
-func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
+func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
if call, ok := p.dialing[addr]; ok {
// A dial is already in-flight. Don't start another.
return call
}
- call := &dialCall{p: p, done: make(chan struct{})}
+ call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
if p.dialing == nil {
p.dialing = make(map[string]*dialCall)
}
p.dialing[addr] = call
- go call.dial(addr)
+ go call.dial(call.ctx, addr)
return call
}
// run in its own goroutine.
-func (c *dialCall) dial(addr string) {
+func (c *dialCall) dial(ctx context.Context, addr string) {
const singleUse = false // shared conn
- c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
+ c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
close(c.done)
c.p.mu.Lock()
@@ -276,3 +286,28 @@
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
return p.getClientConn(req, addr, noDialOnMiss)
}
+
+// shouldRetryDial reports whether the current request should
+// retry dialing after the call finished unsuccessfully, for example
+// if the dial was canceled because of a context cancellation or
+// deadline expiry.
+func shouldRetryDial(call *dialCall, req *http.Request) bool {
+ if call.err == nil {
+ // No error, no need to retry
+ return false
+ }
+ if call.ctx == req.Context() {
+ // If the call has the same context as the request, the dial
+ // should not be retried, since any cancellation will have come
+ // from this request.
+ return false
+ }
+ if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
+ // If the call error is not because of a context cancellation or a deadline expiry,
+ // the dial should not be retried.
+ return false
+ }
+ // Only retry if the error is a context cancellation error or deadline expiry
+ // and the context associated with the call was canceled or expired.
+ return call.ctx.Err() != nil
+}
diff --git a/http2/transport.go b/http2/transport.go
index 7688d72..5ae89cf 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -564,12 +564,12 @@
return false
}
-func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
+func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
- tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
+ tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host))
if err != nil {
return nil, err
}
@@ -590,34 +590,28 @@
return cfg
}
-func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
+func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
}
- return t.dialTLSDefault
-}
-
-func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
- cn, err := tls.Dial(network, addr, cfg)
- if err != nil {
- return nil, err
- }
- if err := cn.Handshake(); err != nil {
- return nil, err
- }
- if !cfg.InsecureSkipVerify {
- if err := cn.VerifyHostname(cfg.ServerName); err != nil {
+ return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ dialer := &tls.Dialer{
+ Config: cfg,
+ }
+ cn, err := dialer.DialContext(ctx, network, addr)
+ if err != nil {
return nil, err
}
+ tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
+ state := tlsCn.ConnectionState()
+ if p := state.NegotiatedProtocol; p != NextProtoTLS {
+ return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
+ }
+ if !state.NegotiatedProtocolIsMutual {
+ return nil, errors.New("http2: could not negotiate protocol mutually")
+ }
+ return cn, nil
}
- state := cn.ConnectionState()
- if p := state.NegotiatedProtocol; p != NextProtoTLS {
- return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
- }
- if !state.NegotiatedProtocolIsMutual {
- return nil, errors.New("http2: could not negotiate protocol mutually")
- }
- return cn, nil
}
// disableKeepAlives reports whether connections should be closed as
diff --git a/http2/transport_go117_test.go b/http2/transport_go117_test.go
new file mode 100644
index 0000000..f5d4e0c
--- /dev/null
+++ b/http2/transport_go117_test.go
@@ -0,0 +1,169 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.17
+// +build go1.17
+
+package http2
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+
+ "testing"
+)
+
+func TestTransportDialTLSContext(t *testing.T) {
+ blockCh := make(chan struct{})
+ serverTLSConfigFunc := func(ts *httptest.Server) {
+ ts.Config.TLSConfig = &tls.Config{
+ // Triggers the server to request the clients certificate
+ // during TLS handshake.
+ ClientAuth: tls.RequestClientCert,
+ }
+ }
+ ts := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {},
+ optOnlyServer,
+ serverTLSConfigFunc,
+ )
+ defer ts.Close()
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ // Tests that the context provided to `req` is
+ // passed into this function.
+ close(blockCh)
+ <-cri.Context().Done()
+ return nil, cri.Context().Err()
+ },
+ InsecureSkipVerify: true,
+ },
+ }
+ defer tr.CloseIdleConnections()
+ req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ req = req.WithContext(ctx)
+ errCh := make(chan error)
+ go func() {
+ defer close(errCh)
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ res.Body.Close()
+ }()
+ // Wait for GetClientCertificate handler to be called
+ <-blockCh
+ // Cancel the context
+ cancel()
+ // Expect the cancellation error here
+ err = <-errCh
+ if err == nil {
+ t.Fatal("cancelling context during client certificate fetch did not error as expected")
+ return
+ }
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("unexpected error returned after cancellation: %v", err)
+ }
+}
+
+// TestDialRaceResumesDial tests that, given two concurrent requests
+// to the same address, when the first Dial is interrupted because
+// the first request's context is cancelled, the second request
+// resumes the dial automatically.
+func TestDialRaceResumesDial(t *testing.T) {
+ blockCh := make(chan struct{})
+ serverTLSConfigFunc := func(ts *httptest.Server) {
+ ts.Config.TLSConfig = &tls.Config{
+ // Triggers the server to request the clients certificate
+ // during TLS handshake.
+ ClientAuth: tls.RequestClientCert,
+ }
+ }
+ ts := newServerTester(t,
+ func(w http.ResponseWriter, r *http.Request) {},
+ optOnlyServer,
+ serverTLSConfigFunc,
+ )
+ defer ts.Close()
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+ select {
+ case <-blockCh:
+ // If we already errored, return without error.
+ return &tls.Certificate{}, nil
+ default:
+ }
+ close(blockCh)
+ <-cri.Context().Done()
+ return nil, cri.Context().Err()
+ },
+ InsecureSkipVerify: true,
+ },
+ }
+ defer tr.CloseIdleConnections()
+ req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Create two requests with independent cancellation.
+ ctx1, cancel1 := context.WithCancel(context.Background())
+ defer cancel1()
+ req1 := req.WithContext(ctx1)
+ ctx2, cancel2 := context.WithCancel(context.Background())
+ defer cancel2()
+ req2 := req.WithContext(ctx2)
+ errCh := make(chan error)
+ go func() {
+ res, err := tr.RoundTrip(req1)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ res.Body.Close()
+ }()
+ successCh := make(chan struct{})
+ go func() {
+ // Don't start request until first request
+ // has initiated the handshake.
+ <-blockCh
+ res, err := tr.RoundTrip(req2)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ res.Body.Close()
+ // Close successCh to indicate that the second request
+ // made it to the server successfully.
+ close(successCh)
+ }()
+ // Wait for GetClientCertificate handler to be called
+ <-blockCh
+ // Cancel the context first
+ cancel1()
+ // Expect the cancellation error here
+ err = <-errCh
+ if err == nil {
+ t.Fatal("cancelling context during client certificate fetch did not error as expected")
+ return
+ }
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("unexpected error returned after cancellation: %v", err)
+ }
+ select {
+ case err := <-errCh:
+ t.Fatalf("unexpected second error: %v", err)
+ case <-successCh:
+ }
+}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index c9c948c..7b13928 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3276,7 +3276,8 @@
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+ ctx := context.Background()
+ cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}
@@ -4278,7 +4279,8 @@
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+ ctx := context.Background()
+ cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
@@ -4788,7 +4790,8 @@
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
- cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+ ctx := context.Background()
+ cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}