chacha20poly1305: add runtime internal independent cpu feature detection

Change-Id: I150c5e0453b0fa3457d4786fe90901a54e216b02
Reviewed-on: https://go-review.googlesource.com/41862
Run-TryBot: Martin Möhrmann <moehrmann@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
diff --git a/chacha20poly1305/chacha20poly1305_amd64.go b/chacha20poly1305/chacha20poly1305_amd64.go
index 4755033..1e523b9 100644
--- a/chacha20poly1305/chacha20poly1305_amd64.go
+++ b/chacha20poly1305/chacha20poly1305_amd64.go
@@ -14,13 +14,60 @@
 //go:noescape
 func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte)
 
-//go:noescape
-func haveSSSE3() bool
+// cpuid is implemented in chacha20poly1305_amd64.s.
+func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
 
-var canUseASM bool
+// xgetbv with ecx = 0 is implemented in chacha20poly1305_amd64.s.
+func xgetbv() (eax, edx uint32)
+
+var (
+	useASM  bool
+	useAVX2 bool
+)
 
 func init() {
-	canUseASM = haveSSSE3()
+	detectCpuFeatures()
+}
+
+// detectCpuFeatures is used to detect if cpu instructions
+// used by the functions implemented in assembler in
+// chacha20poly1305_amd64.s are supported.
+func detectCpuFeatures() {
+	maxId, _, _, _ := cpuid(0, 0)
+	if maxId < 1 {
+		return
+	}
+
+	_, _, ecx1, _ := cpuid(1, 0)
+
+	haveSSSE3 := isSet(9, ecx1)
+	useASM = haveSSSE3
+
+	haveOSXSAVE := isSet(27, ecx1)
+
+	osSupportsAVX := false
+	// For XGETBV, OSXSAVE bit is required and sufficient.
+	if haveOSXSAVE {
+		eax, _ := xgetbv()
+		// Check if XMM and YMM registers have OS support.
+		osSupportsAVX = isSet(1, eax) && isSet(2, eax)
+	}
+	haveAVX := isSet(28, ecx1) && osSupportsAVX
+
+	if maxId < 7 {
+		return
+	}
+
+	_, ebx7, _, _ := cpuid(7, 0)
+	haveAVX2 := isSet(5, ebx7) && haveAVX
+	haveBMI2 := isSet(8, ebx7)
+
+	useAVX2 = haveAVX2 && haveBMI2
+}
+
+// isSet checks if bit at bitpos is set in value.
+func isSet(bitpos uint, value uint32) bool {
+	return value&(1<<bitpos) != 0
 }
 
 // setupState writes a ChaCha20 input matrix to state. See
@@ -47,7 +94,7 @@
 }
 
 func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
-	if !canUseASM {
+	if !useASM {
 		return c.sealGeneric(dst, nonce, plaintext, additionalData)
 	}
 
@@ -60,7 +107,7 @@
 }
 
 func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
-	if !canUseASM {
+	if !useASM {
 		return c.openGeneric(dst, nonce, ciphertext, additionalData)
 	}
 
diff --git a/chacha20poly1305/chacha20poly1305_amd64.s b/chacha20poly1305/chacha20poly1305_amd64.s
index 39c58b4..1c57e38 100644
--- a/chacha20poly1305/chacha20poly1305_amd64.s
+++ b/chacha20poly1305/chacha20poly1305_amd64.s
@@ -278,15 +278,8 @@
 	MOVQ ad+72(FP), adp
 
 	// Check for AVX2 support
-	CMPB runtime·support_avx2(SB), $0
-	JE   noavx2bmi2Open
-
-	// Check BMI2 bit for MULXQ.
-	// runtime·cpuid_ebx7 is always available here
-	// because it passed avx2 check
-	TESTL $(1<<8), runtime·cpuid_ebx7(SB)
-	JNE   chacha20Poly1305Open_AVX2
-noavx2bmi2Open:
+	CMPB ·useAVX2(SB), $1
+	JE   chacha20Poly1305Open_AVX2
 
 	// Special optimization, for very short buffers
 	CMPQ inl, $128
@@ -1491,16 +1484,8 @@
 	MOVQ src_len+56(FP), inl
 	MOVQ ad+72(FP), adp
 
-	// Check for AVX2 support
-	CMPB runtime·support_avx2(SB), $0
-	JE   noavx2bmi2Seal
-
-	// Check BMI2 bit for MULXQ.
-	// runtime·cpuid_ebx7 is always available here
-	// because it passed avx2 check
-	TESTL $(1<<8), runtime·cpuid_ebx7(SB)
-	JNE   chacha20Poly1305Seal_AVX2
-noavx2bmi2Seal:
+	CMPB ·useAVX2(SB), $1
+	JE   chacha20Poly1305Seal_AVX2
 
 	// Special optimization, for very short buffers
 	CMPQ inl, $128
@@ -2709,13 +2694,21 @@
 
 	JMP sealAVX2SealHash
 
-// func haveSSSE3() bool
-TEXT ·haveSSSE3(SB), NOSPLIT, $0
-	XORQ AX, AX
-	INCL AX
+// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
+TEXT ·cpuid(SB), NOSPLIT, $0-24
+	MOVL eaxArg+0(FP), AX
+	MOVL ecxArg+4(FP), CX
 	CPUID
-	SHRQ $9, CX
-	ANDQ $1, CX
-	MOVB CX, ret+0(FP)
+	MOVL AX, eax+8(FP)
+	MOVL BX, ebx+12(FP)
+	MOVL CX, ecx+16(FP)
+	MOVL DX, edx+20(FP)
 	RET
 
+// func xgetbv() (eax, edx uint32)
+TEXT ·xgetbv(SB),NOSPLIT,$0-8
+	MOVL $0, CX
+	XGETBV
+	MOVL AX, eax+0(FP)
+	MOVL DX, edx+4(FP)
+	RET