crypto/sha512: add WriteString and WriteByte method

This can reduce allocations when hashing a string or byte
rather than []byte.

For #38776

Change-Id: I4926ae2749f6b167edbebb73d8f68763ffb2f0c1
Reviewed-on: https://go-review.googlesource.com/c/go/+/483816
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Joel Sing <joel@sing.id.au>
Auto-Submit: Ian Lance Taylor <iant@google.com>
diff --git a/src/crypto/internal/boring/sha.go b/src/crypto/internal/boring/sha.go
index c9772aa..702c687 100644
--- a/src/crypto/internal/boring/sha.go
+++ b/src/crypto/internal/boring/sha.go
@@ -428,6 +428,20 @@
 	return len(p), nil
 }
 
+func (h *sha384Hash) WriteString(s string) (int, error) {
+	if len(s) > 0 && C._goboringcrypto_SHA384_Update(h.noescapeCtx(), unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 {
+		panic("boringcrypto: SHA384_Update failed")
+	}
+	return len(s), nil
+}
+
+func (h *sha384Hash) WriteByte(c byte) error {
+	if C._goboringcrypto_SHA384_Update(h.noescapeCtx(), unsafe.Pointer(&c), 1) == 0 {
+		panic("boringcrypto: SHA384_Update failed")
+	}
+	return nil
+}
+
 func (h0 *sha384Hash) sum(dst []byte) []byte {
 	h := *h0 // make copy so future Write+Sum is valid
 	if C._goboringcrypto_SHA384_Final((*C.uint8_t)(noescape(unsafe.Pointer(&h.out[0]))), h.noescapeCtx()) == 0 {
@@ -466,6 +480,20 @@
 	return len(p), nil
 }
 
+func (h *sha512Hash) WriteString(s string) (int, error) {
+	if len(s) > 0 && C._goboringcrypto_SHA512_Update(h.noescapeCtx(), unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 {
+		panic("boringcrypto: SHA512_Update failed")
+	}
+	return len(s), nil
+}
+
+func (h *sha512Hash) WriteByte(c byte) error {
+	if C._goboringcrypto_SHA512_Update(h.noescapeCtx(), unsafe.Pointer(&c), 1) == 0 {
+		panic("boringcrypto: SHA512_Update failed")
+	}
+	return nil
+}
+
 func (h0 *sha512Hash) sum(dst []byte) []byte {
 	h := *h0 // make copy so future Write+Sum is valid
 	if C._goboringcrypto_SHA512_Final((*C.uint8_t)(noescape(unsafe.Pointer(&h.out[0]))), h.noescapeCtx()) == 0 {
diff --git a/src/crypto/sha512/sha512.go b/src/crypto/sha512/sha512.go
index 9ae1b3a..b22c50b 100644
--- a/src/crypto/sha512/sha512.go
+++ b/src/crypto/sha512/sha512.go
@@ -254,20 +254,10 @@
 func (d *digest) BlockSize() int { return BlockSize }
 
 func (d *digest) Write(p []byte) (nn int, err error) {
-	if d.function != crypto.SHA512_224 && d.function != crypto.SHA512_256 {
-		boring.Unreachable()
-	}
 	nn = len(p)
 	d.len += uint64(nn)
-	if d.nx > 0 {
-		n := copy(d.x[d.nx:], p)
-		d.nx += n
-		if d.nx == chunk {
-			block(d, d.x[:])
-			d.nx = 0
-		}
-		p = p[n:]
-	}
+	n := fillChunk(d, p)
+	p = p[n:]
 	if len(p) >= chunk {
 		n := len(p) &^ (chunk - 1)
 		block(d, p[:n])
@@ -279,6 +269,59 @@
 	return
 }
 
+func (d *digest) WriteString(s string) (nn int, err error) {
+	nn = len(s)
+	d.len += uint64(nn)
+	n := fillChunk(d, s)
+
+	// This duplicates the code in Write, except that it calls
+	// blockString rather than block. It would be nicer to pass
+	// in a func, but as of this writing (Go 1.20) that causes
+	// memory allocations that we want to avoid.
+
+	s = s[n:]
+	if len(s) >= chunk {
+		n := len(s) &^ (chunk - 1)
+		blockString(d, s[:n])
+		s = s[n:]
+	}
+	if len(s) > 0 {
+		d.nx = copy(d.x[:], s)
+	}
+	return
+}
+
+// fillChunk fills the remainder of the current chunk, if any.
+func fillChunk[S []byte | string](d *digest, p S) int {
+	if d.function != crypto.SHA512_224 && d.function != crypto.SHA512_256 {
+		boring.Unreachable()
+	}
+	if d.nx == 0 {
+		return 0
+	}
+	n := copy(d.x[d.nx:], p)
+	d.nx += n
+	if d.nx == chunk {
+		block(d, d.x[:])
+		d.nx = 0
+	}
+	return n
+}
+
+func (d *digest) WriteByte(c byte) error {
+	if d.function != crypto.SHA512_224 && d.function != crypto.SHA512_256 {
+		boring.Unreachable()
+	}
+	d.len++
+	d.x[d.nx] = c
+	d.nx++
+	if d.nx == chunk {
+		block(d, d.x[:])
+		d.nx = 0
+	}
+	return nil
+}
+
 func (d *digest) Sum(in []byte) []byte {
 	if d.function != crypto.SHA512_224 && d.function != crypto.SHA512_256 {
 		boring.Unreachable()
diff --git a/src/crypto/sha512/sha512_test.go b/src/crypto/sha512/sha512_test.go
index 921cdbb..cbe195e3 100644
--- a/src/crypto/sha512/sha512_test.go
+++ b/src/crypto/sha512/sha512_test.go
@@ -676,6 +676,15 @@
 		}
 		digestFunc.Reset()
 	}
+
+	bw := digestFunc.(io.ByteWriter)
+	for i := 0; i < len(in); i++ {
+		bw.WriteByte(in[i])
+	}
+	if calculated := hex.EncodeToString(digestFunc.Sum(nil)); calculated != outHex {
+		t.Errorf("%s(%q) = %q using WriteByte but expected %q", name, in, calculated, outHex)
+	}
+	digestFunc.Reset()
 }
 
 func TestGolden(t *testing.T) {
@@ -896,7 +905,8 @@
 	if boring.Enabled {
 		t.Skip("BoringCrypto doesn't allocate the same way as stdlib")
 	}
-	in := []byte("hello, world!")
+	const ins = "hello, world!"
+	in := []byte(ins)
 	out := make([]byte, 0, Size)
 	h := New()
 	n := int(testing.AllocsPerRun(10, func() {
@@ -907,6 +917,28 @@
 	if n > 0 {
 		t.Errorf("allocs = %d, want 0", n)
 	}
+
+	sw := h.(io.StringWriter)
+	n = int(testing.AllocsPerRun(10, func() {
+		h.Reset()
+		sw.WriteString(ins)
+		out = h.Sum(out[:0])
+	}))
+	if n > 0 {
+		t.Errorf("string allocs = %d, want 0", n)
+	}
+
+	bw := h.(io.ByteWriter)
+	n = int(testing.AllocsPerRun(10, func() {
+		h.Reset()
+		for _, b := range in {
+			bw.WriteByte(b)
+		}
+		out = h.Sum(out[:0])
+	}))
+	if n > 0 {
+		t.Errorf("byte allocs = %d, want 0", n)
+	}
 }
 
 var bench = New()
diff --git a/src/crypto/sha512/sha512block.go b/src/crypto/sha512/sha512block.go
index 81569c5..b0dcf27 100644
--- a/src/crypto/sha512/sha512block.go
+++ b/src/crypto/sha512/sha512block.go
@@ -93,7 +93,7 @@
 	0x6c44198c4a475817,
 }
 
-func blockGeneric(dig *digest, p []byte) {
+func blockGeneric[S []byte | string](dig *digest, p S) {
 	var w [80]uint64
 	h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7]
 	for len(p) >= chunk {
diff --git a/src/crypto/sha512/sha512block_amd64.go b/src/crypto/sha512/sha512block_amd64.go
index 8da3e14..4d9ec5a 100644
--- a/src/crypto/sha512/sha512block_amd64.go
+++ b/src/crypto/sha512/sha512block_amd64.go
@@ -6,20 +6,31 @@
 
 package sha512
 
-import "internal/cpu"
+import (
+	"internal/cpu"
+	"unsafe"
+)
 
 //go:noescape
-func blockAVX2(dig *digest, p []byte)
+func blockAVX2(dig *digest, p *byte, n int)
 
 //go:noescape
-func blockAMD64(dig *digest, p []byte)
+func blockAMD64(dig *digest, p *byte, n int)
 
 var useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI1 && cpu.X86.HasBMI2
 
 func block(dig *digest, p []byte) {
 	if useAVX2 {
-		blockAVX2(dig, p)
+		blockAVX2(dig, unsafe.SliceData(p), len(p))
 	} else {
-		blockAMD64(dig, p)
+		blockAMD64(dig, unsafe.SliceData(p), len(p))
+	}
+}
+
+func blockString(dig *digest, s string) {
+	if useAVX2 {
+		blockAVX2(dig, unsafe.StringData(s), len(s))
+	} else {
+		blockAMD64(dig, unsafe.StringData(s), len(s))
 	}
 }
diff --git a/src/crypto/sha512/sha512block_amd64.s b/src/crypto/sha512/sha512block_amd64.s
index 0fa0df2..e8a89e3 100644
--- a/src/crypto/sha512/sha512block_amd64.s
+++ b/src/crypto/sha512/sha512block_amd64.s
@@ -141,9 +141,9 @@
 	MSGSCHEDULE1(index); \
 	SHA512ROUND(index, const, a, b, c, d, e, f, g, h)
 
-TEXT ·blockAMD64(SB),0,$648-32
-	MOVQ	p_base+8(FP), SI
-	MOVQ	p_len+16(FP), DX
+TEXT ·blockAMD64(SB),0,$648-24
+	MOVQ	p+8(FP), SI
+	MOVQ	n+16(FP), DX
 	SHRQ	$7, DX
 	SHLQ	$7, DX
 
@@ -319,10 +319,10 @@
 
 GLOBL MASK_YMM_LO<>(SB), (NOPTR+RODATA), $32
 
-TEXT ·blockAVX2(SB), NOSPLIT, $56-32
+TEXT ·blockAVX2(SB), NOSPLIT, $56-24
 	MOVQ dig+0(FP), SI
-	MOVQ p_base+8(FP), DI
-	MOVQ p_len+16(FP), DX
+	MOVQ p+8(FP), DI
+	MOVQ n+16(FP), DX
 
 	SHRQ $7, DX
 	SHLQ $7, DX
diff --git a/src/crypto/sha512/sha512block_arm64.go b/src/crypto/sha512/sha512block_arm64.go
index 243eb5c..a916a0a 100644
--- a/src/crypto/sha512/sha512block_arm64.go
+++ b/src/crypto/sha512/sha512block_arm64.go
@@ -4,15 +4,26 @@
 
 package sha512
 
-import "internal/cpu"
+import (
+	"internal/cpu"
+	"unsafe"
+)
 
 func block(dig *digest, p []byte) {
 	if cpu.ARM64.HasSHA512 {
-		blockAsm(dig, p)
+		blockAsm(dig, unsafe.SliceData(p), len(p))
 		return
 	}
 	blockGeneric(dig, p)
 }
 
+func blockString(dig *digest, s string) {
+	if cpu.ARM64.HasSHA512 {
+		blockAsm(dig, unsafe.StringData(s), len(s))
+		return
+	}
+	blockGeneric(dig, s)
+}
+
 //go:noescape
-func blockAsm(dig *digest, p []byte)
+func blockAsm(dig *digest, p *byte, n int)
diff --git a/src/crypto/sha512/sha512block_arm64.s b/src/crypto/sha512/sha512block_arm64.s
index dfc35d6..647ee62 100644
--- a/src/crypto/sha512/sha512block_arm64.s
+++ b/src/crypto/sha512/sha512block_arm64.s
@@ -38,11 +38,11 @@
 	VADD	i3.D2, i1.D2, i4.D2 \
 	SHA512H2	i0.D2, i1, i3
 
-// func blockAsm(dig *digest, p []byte)
+// func blockAsm(dig *digest, p *byte, n int)
 TEXT ·blockAsm(SB),NOSPLIT,$0
 	MOVD	dig+0(FP), R0
-	MOVD	p_base+8(FP), R1
-	MOVD	p_len+16(FP), R2
+	MOVD	p+8(FP), R1
+	MOVD	n+16(FP), R2
 	MOVD	·_K+0(SB), R3
 
 	// long enough to prefetch
diff --git a/src/crypto/sha512/sha512block_decl.go b/src/crypto/sha512/sha512block_decl.go
index 4ad4418..399f13c 100644
--- a/src/crypto/sha512/sha512block_decl.go
+++ b/src/crypto/sha512/sha512block_decl.go
@@ -6,5 +6,15 @@
 
 package sha512
 
+import "unsafe"
+
 //go:noescape
-func block(dig *digest, p []byte)
+func doBlock(dig *digest, p *byte, n int)
+
+func block(dig *digest, p []byte) {
+	doBlock(dig, unsafe.SliceData(p), len(p))
+}
+
+func blockString(dig *digest, s string) {
+	doBlock(dig, unsafe.StringData(s), len(s))
+}
diff --git a/src/crypto/sha512/sha512block_generic.go b/src/crypto/sha512/sha512block_generic.go
index 02ecc2c..116d6c8 100644
--- a/src/crypto/sha512/sha512block_generic.go
+++ b/src/crypto/sha512/sha512block_generic.go
@@ -9,3 +9,7 @@
 func block(dig *digest, p []byte) {
 	blockGeneric(dig, p)
 }
+
+func blockString(dig *digest, s string) {
+	blockGeneric(dig, s)
+}
diff --git a/src/crypto/sha512/sha512block_ppc64x.s b/src/crypto/sha512/sha512block_ppc64x.s
index 90dbf0f..df9a7bb 100644
--- a/src/crypto/sha512/sha512block_ppc64x.s
+++ b/src/crypto/sha512/sha512block_ppc64x.s
@@ -304,11 +304,11 @@
 	VADDUDM		S0, h, h; \
 	VADDUDM		s1, xj, xj
 
-// func block(dig *digest, p []byte)
-TEXT ·block(SB),0,$0-32
+// func doBlock(dig *digest, p *byte, b int)
+TEXT ·doBlock(SB),0,$0-24
 	MOVD	dig+0(FP), CTX
-	MOVD	p_base+8(FP), INP
-	MOVD	p_len+16(FP), LEN
+	MOVD	p+8(FP), INP
+	MOVD	n+16(FP), LEN
 
 	SRD	$6, LEN
 	SLD	$6, LEN
diff --git a/src/crypto/sha512/sha512block_s390x.go b/src/crypto/sha512/sha512block_s390x.go
index 7df29fd..d7412ee 100644
--- a/src/crypto/sha512/sha512block_s390x.go
+++ b/src/crypto/sha512/sha512block_s390x.go
@@ -4,6 +4,13 @@
 
 package sha512
 
-import "internal/cpu"
+import (
+	"internal/cpu"
+	"unsafe"
+)
 
 var useAsm = cpu.S390X.HasSHA512
+
+func doBlockGeneric(dig *digest, p *byte, n int) {
+	blockGeneric(dig, unsafe.String(p, n))
+}
diff --git a/src/crypto/sha512/sha512block_s390x.s b/src/crypto/sha512/sha512block_s390x.s
index f221bd1..3879bf8 100644
--- a/src/crypto/sha512/sha512block_s390x.s
+++ b/src/crypto/sha512/sha512block_s390x.s
@@ -4,8 +4,8 @@
 
 #include "textflag.h"
 
-// func block(dig *digest, p []byte)
-TEXT ·block(SB), NOSPLIT|NOFRAME, $0-32
+// func doBlock(dig *digest, p *byte, n int)
+TEXT ·doBlock(SB), NOSPLIT|NOFRAME, $0-24
 	MOVBZ  ·useAsm(SB), R4
 	LMG    dig+0(FP), R1, R3            // R2 = &p[0], R3 = len(p)
 	MOVBZ  $3, R0                       // SHA-512 function code
@@ -17,4 +17,4 @@
 	RET
 
 generic:
-	BR ·blockGeneric(SB)
+	BR ·doBlockGeneric(SB)