crypto/tls: fix duplicate calls to VerifyConnection

Also add a test that could reproduce this error and
ensure it doesn't occur in other configurations.

Fixes #39012

Change-Id: If792b5131f312c269fd2c5f08c9ed5c00188d1af
Reviewed-on: https://go-review.googlesource.com/c/go/+/233957
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go
index de93e1b..88c974f 100644
--- a/src/crypto/tls/handshake_client_test.go
+++ b/src/crypto/tls/handshake_client_test.go
@@ -1464,6 +1464,225 @@
 	}
 }
 
+func TestVerifyConnection(t *testing.T) {
+	t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
+	t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
+}
+
+func testVerifyConnection(t *testing.T, version uint16) {
+	checkFields := func(c ConnectionState, called *int) error {
+		if c.Version != version {
+			return fmt.Errorf("got Version %v, want %v", c.Version, version)
+		}
+		if c.HandshakeComplete {
+			return fmt.Errorf("got HandshakeComplete, want false")
+		}
+		if c.ServerName != "example.golang" {
+			return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang")
+		}
+		if c.NegotiatedProtocol != "protocol1" {
+			return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1")
+		}
+		wantDidResume := false
+		if *called == 2 { // if this is the second time, then it should be a resumption
+			wantDidResume = true
+		}
+		if c.DidResume != wantDidResume {
+			return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume)
+		}
+		return nil
+	}
+
+	tests := []struct {
+		name            string
+		configureServer func(*Config, *int)
+		configureClient func(*Config, *int)
+	}{
+		{
+			name: "RequireAndVerifyClientCert",
+			configureServer: func(config *Config, called *int) {
+				config.ClientAuth = RequireAndVerifyClientCert
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					if l := len(c.PeerCertificates); l != 1 {
+						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+					}
+					if len(c.VerifiedChains) == 0 {
+						return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
+					}
+					return checkFields(c, called)
+				}
+			},
+			configureClient: func(config *Config, called *int) {
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					if l := len(c.PeerCertificates); l != 1 {
+						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+					}
+					if len(c.VerifiedChains) == 0 {
+						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+					}
+					if c.DidResume {
+						return nil
+						// The SCTs and OCSP Responce are dropped on resumption.
+						// See http://golang.org/issue/39075.
+					}
+					if len(c.OCSPResponse) == 0 {
+						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+					}
+					if len(c.SignedCertificateTimestamps) == 0 {
+						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+					}
+					return checkFields(c, called)
+				}
+			},
+		},
+		{
+			name: "InsecureSkipVerify",
+			configureServer: func(config *Config, called *int) {
+				config.ClientAuth = RequireAnyClientCert
+				config.InsecureSkipVerify = true
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					if l := len(c.PeerCertificates); l != 1 {
+						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+					}
+					if c.VerifiedChains != nil {
+						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+					}
+					return checkFields(c, called)
+				}
+			},
+			configureClient: func(config *Config, called *int) {
+				config.InsecureSkipVerify = true
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					if l := len(c.PeerCertificates); l != 1 {
+						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+					}
+					if c.VerifiedChains != nil {
+						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+					}
+					if c.DidResume {
+						return nil
+						// The SCTs and OCSP Responce are dropped on resumption.
+						// See http://golang.org/issue/39075.
+					}
+					if len(c.OCSPResponse) == 0 {
+						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+					}
+					if len(c.SignedCertificateTimestamps) == 0 {
+						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+					}
+					return checkFields(c, called)
+				}
+			},
+		},
+		{
+			name: "NoClientCert",
+			configureServer: func(config *Config, called *int) {
+				config.ClientAuth = NoClientCert
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					return checkFields(c, called)
+				}
+			},
+			configureClient: func(config *Config, called *int) {
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					return checkFields(c, called)
+				}
+			},
+		},
+		{
+			name: "RequestClientCert",
+			configureServer: func(config *Config, called *int) {
+				config.ClientAuth = RequestClientCert
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					return checkFields(c, called)
+				}
+			},
+			configureClient: func(config *Config, called *int) {
+				config.Certificates = nil // clear the client cert
+				config.VerifyConnection = func(c ConnectionState) error {
+					*called++
+					if l := len(c.PeerCertificates); l != 1 {
+						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+					}
+					if len(c.VerifiedChains) == 0 {
+						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+					}
+					if c.DidResume {
+						return nil
+						// The SCTs and OCSP Responce are dropped on resumption.
+						// See http://golang.org/issue/39075.
+					}
+					if len(c.OCSPResponse) == 0 {
+						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+					}
+					if len(c.SignedCertificateTimestamps) == 0 {
+						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+					}
+					return checkFields(c, called)
+				}
+			},
+		},
+	}
+	for _, test := range tests {
+		issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+		if err != nil {
+			panic(err)
+		}
+		rootCAs := x509.NewCertPool()
+		rootCAs.AddCert(issuer)
+
+		var serverCalled, clientCalled int
+
+		serverConfig := &Config{
+			MaxVersion:   version,
+			Certificates: []Certificate{testConfig.Certificates[0]},
+			ClientCAs:    rootCAs,
+			NextProtos:   []string{"protocol1"},
+		}
+		serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
+		serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
+		test.configureServer(serverConfig, &serverCalled)
+
+		clientConfig := &Config{
+			MaxVersion:         version,
+			ClientSessionCache: NewLRUClientSessionCache(32),
+			RootCAs:            rootCAs,
+			ServerName:         "example.golang",
+			Certificates:       []Certificate{testConfig.Certificates[0]},
+			NextProtos:         []string{"protocol1"},
+		}
+		test.configureClient(clientConfig, &clientCalled)
+
+		testHandshakeState := func(name string, didResume bool) {
+			_, hs, err := testHandshake(t, clientConfig, serverConfig)
+			if err != nil {
+				t.Fatalf("%s: handshake failed: %s", name, err)
+			}
+			if hs.DidResume != didResume {
+				t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
+			}
+			wantCalled := 1
+			if didResume {
+				wantCalled = 2 // resumption would mean this is the second time it was called in this test
+			}
+			if clientCalled != wantCalled {
+				t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
+			}
+			if serverCalled != wantCalled {
+				t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
+			}
+		}
+		testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
+		testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
+	}
+}
+
 func TestVerifyPeerCertificate(t *testing.T) {
 	t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
 	t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go
index 57fba10..2c2f0a4 100644
--- a/src/crypto/tls/handshake_server.go
+++ b/src/crypto/tls/handshake_server.go
@@ -425,6 +425,13 @@
 		return err
 	}
 
+	if c.config.VerifyConnection != nil {
+		if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+			c.sendAlert(alertBadCertificate)
+			return err
+		}
+	}
+
 	hs.masterSecret = hs.sessionState.masterSecret
 
 	return nil
@@ -548,14 +555,11 @@
 		if err != nil {
 			return err
 		}
-	} else {
-		// Make sure the connection is still being verified whether or not
-		// the server requested a client certificate.
-		if c.config.VerifyConnection != nil {
-			if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
-				c.sendAlert(alertBadCertificate)
-				return err
-			}
+	}
+	if c.config.VerifyConnection != nil {
+		if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+			c.sendAlert(alertBadCertificate)
+			return err
 		}
 	}
 
@@ -805,13 +809,6 @@
 		}
 	}
 
-	if c.config.VerifyConnection != nil {
-		if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
-			c.sendAlert(alertBadCertificate)
-			return err
-		}
-	}
-
 	return nil
 }
 
diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go
index fb7f871..92d55e0 100644
--- a/src/crypto/tls/handshake_server_tls13.go
+++ b/src/crypto/tls/handshake_server_tls13.go
@@ -783,6 +783,13 @@
 		return err
 	}
 
+	if c.config.VerifyConnection != nil {
+		if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
+			c.sendAlert(alertBadCertificate)
+			return err
+		}
+	}
+
 	if len(certMsg.certificate.Certificate) != 0 {
 		msg, err = c.readHandshake()
 		if err != nil {