crypto/x509: return additional chains from Verify on Windows

Previously windows only returned the certificate-chain with the highest quality.
This change makes it so chains with a potentially lower quality
originating from other root certificates are also returned by verify.

Tests in verify_test flagged with systemLax are now allowed to pass if the system returns additional chains

Fixes #40604

Change-Id: I66edc233219f581039d47a15f2200ff627154691
Reviewed-on: https://go-review.googlesource.com/c/go/+/257257
Reviewed-by: Tobias Klauser <tobias.klauser@gmail.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Trust: Tobias Klauser <tobias.klauser@gmail.com>
Run-TryBot: Tobias Klauser <tobias.klauser@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
diff --git a/src/crypto/x509/root_windows.go b/src/crypto/x509/root_windows.go
index 22e5a93..1e9be80 100644
--- a/src/crypto/x509/root_windows.go
+++ b/src/crypto/x509/root_windows.go
@@ -155,6 +155,44 @@
 	}
 }
 
+func verifyChain(c *Certificate, chainCtx *syscall.CertChainContext, opts *VerifyOptions) (chain []*Certificate, err error) {
+	err = checkChainTrustStatus(c, chainCtx)
+	if err != nil {
+		return nil, err
+	}
+
+	if opts != nil && len(opts.DNSName) > 0 {
+		err = checkChainSSLServerPolicy(c, chainCtx, opts)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	chain, err = extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount))
+	if err != nil {
+		return nil, err
+	}
+	if len(chain) == 0 {
+		return nil, errors.New("x509: internal error: system verifier returned an empty chain")
+	}
+
+	// Mitigate CVE-2020-0601, where the Windows system verifier might be
+	// tricked into using custom curve parameters for a trusted root, by
+	// double-checking all ECDSA signatures. If the system was tricked into
+	// using spoofed parameters, the signature will be invalid for the correct
+	// ones we parsed. (We don't support custom curves ourselves.)
+	for i, parent := range chain[1:] {
+		if parent.PublicKeyAlgorithm != ECDSA {
+			continue
+		}
+		if err := parent.CheckSignature(chain[i].SignatureAlgorithm,
+			chain[i].RawTBSCertificate, chain[i].Signature); err != nil {
+			return nil, err
+		}
+	}
+	return chain, nil
+}
+
 // systemVerify is like Verify, except that it uses CryptoAPI calls
 // to build certificate chains and verify them.
 func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
@@ -202,67 +240,41 @@
 		verifyTime = &ft
 	}
 
-	// CertGetCertificateChain will traverse Windows's root stores
-	// in an attempt to build a verified certificate chain. Once
-	// it has found a verified chain, it stops. MSDN docs on
-	// CERT_CHAIN_CONTEXT:
-	//
-	//   When a CERT_CHAIN_CONTEXT is built, the first simple chain
-	//   begins with an end certificate and ends with a self-signed
-	//   certificate. If that self-signed certificate is not a root
-	//   or otherwise trusted certificate, an attempt is made to
-	//   build a new chain. CTLs are used to create the new chain
-	//   beginning with the self-signed certificate from the original
-	//   chain as the end certificate of the new chain. This process
-	//   continues building additional simple chains until the first
-	//   self-signed certificate is a trusted certificate or until
-	//   an additional simple chain cannot be built.
-	//
-	// The result is that we'll only get a single trusted chain to
-	// return to our caller.
-	var chainCtx *syscall.CertChainContext
-	err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, 0, 0, &chainCtx)
+	// The default is to return only the highest quality chain,
+	// setting this flag will add additional lower quality contexts.
+	// These are returned in the LowerQualityChains field.
+	const CERT_CHAIN_RETURN_LOWER_QUALITY_CONTEXTS = 0x00000080
+
+	// CertGetCertificateChain will traverse Windows's root stores in an attempt to build a verified certificate chain
+	var topCtx *syscall.CertChainContext
+	err = syscall.CertGetCertificateChain(syscall.Handle(0), storeCtx, verifyTime, storeCtx.Store, para, CERT_CHAIN_RETURN_LOWER_QUALITY_CONTEXTS, 0, &topCtx)
 	if err != nil {
 		return nil, err
 	}
-	defer syscall.CertFreeCertificateChain(chainCtx)
+	defer syscall.CertFreeCertificateChain(topCtx)
 
-	err = checkChainTrustStatus(c, chainCtx)
-	if err != nil {
-		return nil, err
+	chain, topErr := verifyChain(c, topCtx, opts)
+	if topErr == nil {
+		chains = append(chains, chain)
 	}
 
-	if opts != nil && len(opts.DNSName) > 0 {
-		err = checkChainSSLServerPolicy(c, chainCtx, opts)
-		if err != nil {
-			return nil, err
+	if lqCtxCount := topCtx.LowerQualityChainCount; lqCtxCount > 0 {
+		lqCtxs := (*[1 << 20]*syscall.CertChainContext)(unsafe.Pointer(topCtx.LowerQualityChains))[:lqCtxCount:lqCtxCount]
+
+		for _, ctx := range lqCtxs {
+			chain, err := verifyChain(c, ctx, opts)
+			if err == nil {
+				chains = append(chains, chain)
+			}
 		}
 	}
 
-	chain, err := extractSimpleChain(chainCtx.Chains, int(chainCtx.ChainCount))
-	if err != nil {
-		return nil, err
-	}
-	if len(chain) < 1 {
-		return nil, errors.New("x509: internal error: system verifier returned an empty chain")
+	if len(chains) == 0 {
+		// Return the error from the highest quality context.
+		return nil, topErr
 	}
 
-	// Mitigate CVE-2020-0601, where the Windows system verifier might be
-	// tricked into using custom curve parameters for a trusted root, by
-	// double-checking all ECDSA signatures. If the system was tricked into
-	// using spoofed parameters, the signature will be invalid for the correct
-	// ones we parsed. (We don't support custom curves ourselves.)
-	for i, parent := range chain[1:] {
-		if parent.PublicKeyAlgorithm != ECDSA {
-			continue
-		}
-		if err := parent.CheckSignature(chain[i].SignatureAlgorithm,
-			chain[i].RawTBSCertificate, chain[i].Signature); err != nil {
-			return nil, err
-		}
-	}
-
-	return [][]*Certificate{chain}, nil
+	return chains, nil
 }
 
 func loadSystemRoots() (*CertPool, error) {
diff --git a/src/crypto/x509/verify_test.go b/src/crypto/x509/verify_test.go
index 9cc17c7..8e0a7be 100644
--- a/src/crypto/x509/verify_test.go
+++ b/src/crypto/x509/verify_test.go
@@ -550,34 +550,55 @@
 		}
 	}
 
-	if len(chains) != len(test.expectedChains) {
-		t.Errorf("wanted %d chains, got %d", len(test.expectedChains), len(chains))
+	doesMatch := func(expectedChain []string, chain []*Certificate) bool {
+		if len(chain) != len(expectedChain) {
+			return false
+		}
+
+		for k, cert := range chain {
+			if !strings.Contains(nameToKey(&cert.Subject), expectedChain[k]) {
+				return false
+			}
+		}
+		return true
 	}
 
-	// We check that each returned chain matches a chain from
-	// expectedChains but an entry in expectedChains can't match
-	// two chains.
-	seenChains := make([]bool, len(chains))
-NextOutputChain:
-	for _, chain := range chains {
-	TryNextExpected:
-		for j, expectedChain := range test.expectedChains {
-			if seenChains[j] {
-				continue
+	// Every expected chain should match 1 returned chain
+	for _, expectedChain := range test.expectedChains {
+		nChainMatched := 0
+		for _, chain := range chains {
+			if doesMatch(expectedChain, chain) {
+				nChainMatched++
 			}
-			if len(chain) != len(expectedChain) {
-				continue
-			}
-			for k, cert := range chain {
-				if !strings.Contains(nameToKey(&cert.Subject), expectedChain[k]) {
-					continue TryNextExpected
+		}
+
+		if nChainMatched != 1 {
+			t.Errorf("Got %v matches instead of %v for expected chain %v", nChainMatched, 1, expectedChain)
+			for _, chain := range chains {
+				if doesMatch(expectedChain, chain) {
+					t.Errorf("\t matched %v", chainToDebugString(chain))
 				}
 			}
-			// we matched
-			seenChains[j] = true
-			continue NextOutputChain
 		}
-		t.Errorf("no expected chain matched %s", chainToDebugString(chain))
+	}
+
+	// Every returned chain should match 1 expected chain (or <2 if testing against the system)
+	for _, chain := range chains {
+		nMatched := 0
+		for _, expectedChain := range test.expectedChains {
+			if doesMatch(expectedChain, chain) {
+				nMatched++
+			}
+		}
+		// Allow additional unknown chains if systemLax is set
+		if nMatched == 0 && test.systemLax == false || nMatched > 1 {
+			t.Errorf("Got %v matches for chain %v", nMatched, chainToDebugString(chain))
+			for _, expectedChain := range test.expectedChains {
+				if doesMatch(expectedChain, chain) {
+					t.Errorf("\t matched %v", expectedChain)
+				}
+			}
+		}
 	}
 }