http2: use (*tls.Dialer).DialContext in dialTLS

This lets us propagate the request context into the TLS
handshake.

Related to CL 295370
Updates golang/go#32406

Change-Id: Ie10c301be19b57b4b3e46ac31bbe87679e1eebc7
Reviewed-on: https://go-review.googlesource.com/c/net/+/295173
Trust: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index 3a67636..652bc11 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -7,7 +7,9 @@
 package http2
 
 import (
+	"context"
 	"crypto/tls"
+	"errors"
 	"net/http"
 	"sync"
 )
@@ -78,61 +80,69 @@
 		// It gets its own connection.
 		traceGetConn(req, addr)
 		const singleUse = true
-		cc, err := p.t.dialClientConn(addr, singleUse)
+		cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
 		if err != nil {
 			return nil, err
 		}
 		return cc, nil
 	}
-	p.mu.Lock()
-	for _, cc := range p.conns[addr] {
-		if st := cc.idleState(); st.canTakeNewRequest {
-			if p.shouldTraceGetConn(st) {
-				traceGetConn(req, addr)
+	for {
+		p.mu.Lock()
+		for _, cc := range p.conns[addr] {
+			if st := cc.idleState(); st.canTakeNewRequest {
+				if p.shouldTraceGetConn(st) {
+					traceGetConn(req, addr)
+				}
+				p.mu.Unlock()
+				return cc, nil
 			}
-			p.mu.Unlock()
-			return cc, nil
 		}
-	}
-	if !dialOnMiss {
+		if !dialOnMiss {
+			p.mu.Unlock()
+			return nil, ErrNoCachedConn
+		}
+		traceGetConn(req, addr)
+		call := p.getStartDialLocked(req.Context(), addr)
 		p.mu.Unlock()
-		return nil, ErrNoCachedConn
+		<-call.done
+		if shouldRetryDial(call, req) {
+			continue
+		}
+		return call.res, call.err
 	}
-	traceGetConn(req, addr)
-	call := p.getStartDialLocked(addr)
-	p.mu.Unlock()
-	<-call.done
-	return call.res, call.err
 }
 
 // dialCall is an in-flight Transport dial call to a host.
 type dialCall struct {
-	_    incomparable
-	p    *clientConnPool
+	_ incomparable
+	p *clientConnPool
+	// the context associated with the request
+	// that created this dialCall
+	ctx  context.Context
 	done chan struct{} // closed when done
 	res  *ClientConn   // valid after done is closed
 	err  error         // valid after done is closed
 }
 
 // requires p.mu is held.
-func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
+func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
 	if call, ok := p.dialing[addr]; ok {
 		// A dial is already in-flight. Don't start another.
 		return call
 	}
-	call := &dialCall{p: p, done: make(chan struct{})}
+	call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
 	if p.dialing == nil {
 		p.dialing = make(map[string]*dialCall)
 	}
 	p.dialing[addr] = call
-	go call.dial(addr)
+	go call.dial(call.ctx, addr)
 	return call
 }
 
 // run in its own goroutine.
-func (c *dialCall) dial(addr string) {
+func (c *dialCall) dial(ctx context.Context, addr string) {
 	const singleUse = false // shared conn
-	c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
+	c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
 	close(c.done)
 
 	c.p.mu.Lock()
@@ -276,3 +286,28 @@
 func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
 	return p.getClientConn(req, addr, noDialOnMiss)
 }
+
+// shouldRetryDial reports whether the current request should
+// retry dialing after the call finished unsuccessfully, for example
+// if the dial was canceled because of a context cancellation or
+// deadline expiry.
+func shouldRetryDial(call *dialCall, req *http.Request) bool {
+	if call.err == nil {
+		// No error, no need to retry
+		return false
+	}
+	if call.ctx == req.Context() {
+		// If the call has the same context as the request, the dial
+		// should not be retried, since any cancellation will have come
+		// from this request.
+		return false
+	}
+	if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
+		// If the call error is not because of a context cancellation or a deadline expiry,
+		// the dial should not be retried.
+		return false
+	}
+	// Only retry if the error is a context cancellation error or deadline expiry
+	// and the context associated with the call was canceled or expired.
+	return call.ctx.Err() != nil
+}
diff --git a/http2/transport.go b/http2/transport.go
index 7688d72..5ae89cf 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -564,12 +564,12 @@
 	return false
 }
 
-func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
+func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
 	host, _, err := net.SplitHostPort(addr)
 	if err != nil {
 		return nil, err
 	}
-	tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
+	tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host))
 	if err != nil {
 		return nil, err
 	}
@@ -590,34 +590,28 @@
 	return cfg
 }
 
-func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
+func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
 	if t.DialTLS != nil {
 		return t.DialTLS
 	}
-	return t.dialTLSDefault
-}
-
-func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
-	cn, err := tls.Dial(network, addr, cfg)
-	if err != nil {
-		return nil, err
-	}
-	if err := cn.Handshake(); err != nil {
-		return nil, err
-	}
-	if !cfg.InsecureSkipVerify {
-		if err := cn.VerifyHostname(cfg.ServerName); err != nil {
+	return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
+		dialer := &tls.Dialer{
+			Config: cfg,
+		}
+		cn, err := dialer.DialContext(ctx, network, addr)
+		if err != nil {
 			return nil, err
 		}
+		tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
+		state := tlsCn.ConnectionState()
+		if p := state.NegotiatedProtocol; p != NextProtoTLS {
+			return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
+		}
+		if !state.NegotiatedProtocolIsMutual {
+			return nil, errors.New("http2: could not negotiate protocol mutually")
+		}
+		return cn, nil
 	}
-	state := cn.ConnectionState()
-	if p := state.NegotiatedProtocol; p != NextProtoTLS {
-		return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
-	}
-	if !state.NegotiatedProtocolIsMutual {
-		return nil, errors.New("http2: could not negotiate protocol mutually")
-	}
-	return cn, nil
 }
 
 // disableKeepAlives reports whether connections should be closed as
diff --git a/http2/transport_go117_test.go b/http2/transport_go117_test.go
new file mode 100644
index 0000000..f5d4e0c
--- /dev/null
+++ b/http2/transport_go117_test.go
@@ -0,0 +1,169 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.17
+// +build go1.17
+
+package http2
+
+import (
+	"context"
+	"crypto/tls"
+	"errors"
+	"net/http"
+	"net/http/httptest"
+
+	"testing"
+)
+
+func TestTransportDialTLSContext(t *testing.T) {
+	blockCh := make(chan struct{})
+	serverTLSConfigFunc := func(ts *httptest.Server) {
+		ts.Config.TLSConfig = &tls.Config{
+			// Triggers the server to request the clients certificate
+			// during TLS handshake.
+			ClientAuth: tls.RequestClientCert,
+		}
+	}
+	ts := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {},
+		optOnlyServer,
+		serverTLSConfigFunc,
+	)
+	defer ts.Close()
+	tr := &Transport{
+		TLSClientConfig: &tls.Config{
+			GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+				// Tests that the context provided to `req` is
+				// passed into this function.
+				close(blockCh)
+				<-cri.Context().Done()
+				return nil, cri.Context().Err()
+			},
+			InsecureSkipVerify: true,
+		},
+	}
+	defer tr.CloseIdleConnections()
+	req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	req = req.WithContext(ctx)
+	errCh := make(chan error)
+	go func() {
+		defer close(errCh)
+		res, err := tr.RoundTrip(req)
+		if err != nil {
+			errCh <- err
+			return
+		}
+		res.Body.Close()
+	}()
+	// Wait for GetClientCertificate handler to be called
+	<-blockCh
+	// Cancel the context
+	cancel()
+	// Expect the cancellation error here
+	err = <-errCh
+	if err == nil {
+		t.Fatal("cancelling context during client certificate fetch did not error as expected")
+		return
+	}
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("unexpected error returned after cancellation: %v", err)
+	}
+}
+
+// TestDialRaceResumesDial tests that, given two concurrent requests
+// to the same address, when the first Dial is interrupted because
+// the first request's context is cancelled, the second request
+// resumes the dial automatically.
+func TestDialRaceResumesDial(t *testing.T) {
+	blockCh := make(chan struct{})
+	serverTLSConfigFunc := func(ts *httptest.Server) {
+		ts.Config.TLSConfig = &tls.Config{
+			// Triggers the server to request the clients certificate
+			// during TLS handshake.
+			ClientAuth: tls.RequestClientCert,
+		}
+	}
+	ts := newServerTester(t,
+		func(w http.ResponseWriter, r *http.Request) {},
+		optOnlyServer,
+		serverTLSConfigFunc,
+	)
+	defer ts.Close()
+	tr := &Transport{
+		TLSClientConfig: &tls.Config{
+			GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+				select {
+				case <-blockCh:
+					// If we already errored, return without error.
+					return &tls.Certificate{}, nil
+				default:
+				}
+				close(blockCh)
+				<-cri.Context().Done()
+				return nil, cri.Context().Err()
+			},
+			InsecureSkipVerify: true,
+		},
+	}
+	defer tr.CloseIdleConnections()
+	req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Create two requests with independent cancellation.
+	ctx1, cancel1 := context.WithCancel(context.Background())
+	defer cancel1()
+	req1 := req.WithContext(ctx1)
+	ctx2, cancel2 := context.WithCancel(context.Background())
+	defer cancel2()
+	req2 := req.WithContext(ctx2)
+	errCh := make(chan error)
+	go func() {
+		res, err := tr.RoundTrip(req1)
+		if err != nil {
+			errCh <- err
+			return
+		}
+		res.Body.Close()
+	}()
+	successCh := make(chan struct{})
+	go func() {
+		// Don't start request until first request
+		// has initiated the handshake.
+		<-blockCh
+		res, err := tr.RoundTrip(req2)
+		if err != nil {
+			errCh <- err
+			return
+		}
+		res.Body.Close()
+		// Close successCh to indicate that the second request
+		// made it to the server successfully.
+		close(successCh)
+	}()
+	// Wait for GetClientCertificate handler to be called
+	<-blockCh
+	// Cancel the context first
+	cancel1()
+	// Expect the cancellation error here
+	err = <-errCh
+	if err == nil {
+		t.Fatal("cancelling context during client certificate fetch did not error as expected")
+		return
+	}
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("unexpected error returned after cancellation: %v", err)
+	}
+	select {
+	case err := <-errCh:
+		t.Fatalf("unexpected second error: %v", err)
+	case <-successCh:
+	}
+}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index c9c948c..7b13928 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3276,7 +3276,8 @@
 	defer st.Close()
 	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
 	defer tr.CloseIdleConnections()
-	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+	ctx := context.Background()
+	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -4278,7 +4279,8 @@
 	defer st.Close()
 	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
 	defer tr.CloseIdleConnections()
-	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+	ctx := context.Background()
+	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
 	req, err := http.NewRequest("GET", st.ts.URL, nil)
 	if err != nil {
 		t.Fatal(err)
@@ -4788,7 +4790,8 @@
 
 	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
 	defer tr.CloseIdleConnections()
-	cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
+	ctx := context.Background()
+	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
 	if err != nil {
 		t.Fatal(err)
 	}