curve25519: use crypto/ecdh on Go 1.20

For golang/go#52221

Change-Id: I27e867d4cc89cd52c8d510f0dbab4e89b7cd4763
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/451115
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
diff --git a/curve25519/curve25519.go b/curve25519/curve25519.go
index bc62161..00f963e 100644
--- a/curve25519/curve25519.go
+++ b/curve25519/curve25519.go
@@ -5,71 +5,18 @@
 // Package curve25519 provides an implementation of the X25519 function, which
 // performs scalar multiplication on the elliptic curve known as Curve25519.
 // See RFC 7748.
+//
+// Starting in Go 1.20, this package is a wrapper for the X25519 implementation
+// in the crypto/ecdh package.
 package curve25519 // import "golang.org/x/crypto/curve25519"
 
-import (
-	"crypto/subtle"
-	"errors"
-	"strconv"
-
-	"golang.org/x/crypto/curve25519/internal/field"
-)
-
 // ScalarMult sets dst to the product scalar * point.
 //
 // Deprecated: when provided a low-order point, ScalarMult will set dst to all
 // zeroes, irrespective of the scalar. Instead, use the X25519 function, which
 // will return an error.
 func ScalarMult(dst, scalar, point *[32]byte) {
-	var e [32]byte
-
-	copy(e[:], scalar[:])
-	e[0] &= 248
-	e[31] &= 127
-	e[31] |= 64
-
-	var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
-	x1.SetBytes(point[:])
-	x2.One()
-	x3.Set(&x1)
-	z3.One()
-
-	swap := 0
-	for pos := 254; pos >= 0; pos-- {
-		b := e[pos/8] >> uint(pos&7)
-		b &= 1
-		swap ^= int(b)
-		x2.Swap(&x3, swap)
-		z2.Swap(&z3, swap)
-		swap = int(b)
-
-		tmp0.Subtract(&x3, &z3)
-		tmp1.Subtract(&x2, &z2)
-		x2.Add(&x2, &z2)
-		z2.Add(&x3, &z3)
-		z3.Multiply(&tmp0, &x2)
-		z2.Multiply(&z2, &tmp1)
-		tmp0.Square(&tmp1)
-		tmp1.Square(&x2)
-		x3.Add(&z3, &z2)
-		z2.Subtract(&z3, &z2)
-		x2.Multiply(&tmp1, &tmp0)
-		tmp1.Subtract(&tmp1, &tmp0)
-		z2.Square(&z2)
-
-		z3.Mult32(&tmp1, 121666)
-		x3.Square(&x3)
-		tmp0.Add(&tmp0, &z3)
-		z3.Multiply(&x1, &z2)
-		z2.Multiply(&tmp1, &tmp0)
-	}
-
-	x2.Swap(&x3, swap)
-	z2.Swap(&z3, swap)
-
-	z2.Invert(&z2)
-	x2.Multiply(&x2, &z2)
-	copy(dst[:], x2.Bytes())
+	scalarMult(dst, scalar, point)
 }
 
 // ScalarBaseMult sets dst to the product scalar * base where base is the
@@ -78,7 +25,7 @@
 // It is recommended to use the X25519 function with Basepoint instead, as
 // copying into fixed size arrays can lead to unexpected bugs.
 func ScalarBaseMult(dst, scalar *[32]byte) {
-	ScalarMult(dst, scalar, &basePoint)
+	scalarBaseMult(dst, scalar)
 }
 
 const (
@@ -91,21 +38,10 @@
 // Basepoint is the canonical Curve25519 generator.
 var Basepoint []byte
 
-var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+var basePoint = [32]byte{9}
 
 func init() { Basepoint = basePoint[:] }
 
-func checkBasepoint() {
-	if subtle.ConstantTimeCompare(Basepoint, []byte{
-		0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-	}) != 1 {
-		panic("curve25519: global Basepoint value was modified")
-	}
-}
-
 // X25519 returns the result of the scalar multiplication (scalar * point),
 // according to RFC 7748, Section 5. scalar, point and the return value are
 // slices of 32 bytes.
@@ -121,26 +57,3 @@
 	var dst [32]byte
 	return x25519(&dst, scalar, point)
 }
-
-func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
-	var in [32]byte
-	if l := len(scalar); l != 32 {
-		return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
-	}
-	if l := len(point); l != 32 {
-		return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
-	}
-	copy(in[:], scalar)
-	if &point[0] == &Basepoint[0] {
-		checkBasepoint()
-		ScalarBaseMult(dst, &in)
-	} else {
-		var base, zero [32]byte
-		copy(base[:], point)
-		ScalarMult(dst, &in, &base)
-		if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
-			return nil, errors.New("bad input point: low order point")
-		}
-	}
-	return dst[:], nil
-}
diff --git a/curve25519/curve25519_compat.go b/curve25519/curve25519_compat.go
new file mode 100644
index 0000000..ba647e8
--- /dev/null
+++ b/curve25519/curve25519_compat.go
@@ -0,0 +1,105 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !go1.20
+
+package curve25519
+
+import (
+	"crypto/subtle"
+	"errors"
+	"strconv"
+
+	"golang.org/x/crypto/curve25519/internal/field"
+)
+
+func scalarMult(dst, scalar, point *[32]byte) {
+	var e [32]byte
+
+	copy(e[:], scalar[:])
+	e[0] &= 248
+	e[31] &= 127
+	e[31] |= 64
+
+	var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
+	x1.SetBytes(point[:])
+	x2.One()
+	x3.Set(&x1)
+	z3.One()
+
+	swap := 0
+	for pos := 254; pos >= 0; pos-- {
+		b := e[pos/8] >> uint(pos&7)
+		b &= 1
+		swap ^= int(b)
+		x2.Swap(&x3, swap)
+		z2.Swap(&z3, swap)
+		swap = int(b)
+
+		tmp0.Subtract(&x3, &z3)
+		tmp1.Subtract(&x2, &z2)
+		x2.Add(&x2, &z2)
+		z2.Add(&x3, &z3)
+		z3.Multiply(&tmp0, &x2)
+		z2.Multiply(&z2, &tmp1)
+		tmp0.Square(&tmp1)
+		tmp1.Square(&x2)
+		x3.Add(&z3, &z2)
+		z2.Subtract(&z3, &z2)
+		x2.Multiply(&tmp1, &tmp0)
+		tmp1.Subtract(&tmp1, &tmp0)
+		z2.Square(&z2)
+
+		z3.Mult32(&tmp1, 121666)
+		x3.Square(&x3)
+		tmp0.Add(&tmp0, &z3)
+		z3.Multiply(&x1, &z2)
+		z2.Multiply(&tmp1, &tmp0)
+	}
+
+	x2.Swap(&x3, swap)
+	z2.Swap(&z3, swap)
+
+	z2.Invert(&z2)
+	x2.Multiply(&x2, &z2)
+	copy(dst[:], x2.Bytes())
+}
+
+func scalarBaseMult(dst, scalar *[32]byte) {
+	checkBasepoint()
+	scalarMult(dst, scalar, &basePoint)
+}
+
+func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
+	var in [32]byte
+	if l := len(scalar); l != 32 {
+		return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32")
+	}
+	if l := len(point); l != 32 {
+		return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32")
+	}
+	copy(in[:], scalar)
+	if &point[0] == &Basepoint[0] {
+		scalarBaseMult(dst, &in)
+	} else {
+		var base, zero [32]byte
+		copy(base[:], point)
+		scalarMult(dst, &in, &base)
+		if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 {
+			return nil, errors.New("bad input point: low order point")
+		}
+	}
+	return dst[:], nil
+}
+
+func checkBasepoint() {
+	if subtle.ConstantTimeCompare(Basepoint, []byte{
+		0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+	}) != 1 {
+		panic("curve25519: global Basepoint value was modified")
+	}
+}
diff --git a/curve25519/curve25519_go120.go b/curve25519/curve25519_go120.go
new file mode 100644
index 0000000..627df49
--- /dev/null
+++ b/curve25519/curve25519_go120.go
@@ -0,0 +1,46 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.20
+
+package curve25519
+
+import "crypto/ecdh"
+
+func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) {
+	curve := ecdh.X25519()
+	pub, err := curve.NewPublicKey(point)
+	if err != nil {
+		return nil, err
+	}
+	priv, err := curve.NewPrivateKey(scalar)
+	if err != nil {
+		return nil, err
+	}
+	out, err := priv.ECDH(pub)
+	if err != nil {
+		return nil, err
+	}
+	copy(dst[:], out)
+	return dst[:], nil
+}
+
+func scalarMult(dst, scalar, point *[32]byte) {
+	if _, err := x25519(dst, scalar[:], point[:]); err != nil {
+		// The only error condition for x25519 when the inputs are 32 bytes long
+		// is if the output would have been the all-zero value.
+		for i := range dst {
+			dst[i] = 0
+		}
+	}
+}
+
+func scalarBaseMult(dst, scalar *[32]byte) {
+	curve := ecdh.X25519()
+	priv, err := curve.NewPrivateKey(scalar[:])
+	if err != nil {
+		panic("curve25519: internal error: scalarBaseMult was not 32 bytes")
+	}
+	copy(dst[:], priv.PublicKey().Bytes())
+}
diff --git a/curve25519/curve25519_test.go b/curve25519/curve25519_test.go
index 5a31541..e2b338b 100644
--- a/curve25519/curve25519_test.go
+++ b/curve25519/curve25519_test.go
@@ -2,13 +2,15 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package curve25519
+package curve25519_test
 
 import (
 	"bytes"
 	"crypto/rand"
 	"encoding/hex"
 	"testing"
+
+	"golang.org/x/crypto/curve25519"
 )
 
 const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
@@ -19,7 +21,7 @@
 
 	for i := 0; i < 200; i++ {
 		var err error
-		x, err = X25519(x, Basepoint)
+		x, err = curve25519.X25519(x, curve25519.Basepoint)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -32,12 +34,12 @@
 }
 
 func TestLowOrderPoints(t *testing.T) {
-	scalar := make([]byte, ScalarSize)
+	scalar := make([]byte, curve25519.ScalarSize)
 	if _, err := rand.Read(scalar); err != nil {
 		t.Fatal(err)
 	}
 	for i, p := range lowOrderPoints {
-		out, err := X25519(scalar, p)
+		out, err := curve25519.X25519(scalar, p)
 		if err == nil {
 			t.Errorf("%d: expected error, got nil", i)
 		}
@@ -48,10 +50,10 @@
 }
 
 func TestTestVectors(t *testing.T) {
-	t.Run("Legacy", func(t *testing.T) { testTestVectors(t, ScalarMult) })
+	t.Run("Legacy", func(t *testing.T) { testTestVectors(t, curve25519.ScalarMult) })
 	t.Run("X25519", func(t *testing.T) {
 		testTestVectors(t, func(dst, scalar, point *[32]byte) {
-			out, err := X25519(scalar[:], point[:])
+			out, err := curve25519.X25519(scalar[:], point[:])
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -88,10 +90,10 @@
 	var hi0, hi1 [32]byte
 
 	u[31] &= 0x7f
-	ScalarMult(&hi0, &s, &u)
+	curve25519.ScalarMult(&hi0, &s, &u)
 
 	u[31] |= 0x80
-	ScalarMult(&hi1, &s, &u)
+	curve25519.ScalarMult(&hi1, &s, &u)
 
 	if !bytes.Equal(hi0[:], hi1[:]) {
 		t.Errorf("high bit of group point should not affect result")
@@ -101,14 +103,14 @@
 var benchmarkSink byte
 
 func BenchmarkX25519Basepoint(b *testing.B) {
-	scalar := make([]byte, ScalarSize)
+	scalar := make([]byte, curve25519.ScalarSize)
 	if _, err := rand.Read(scalar); err != nil {
 		b.Fatal(err)
 	}
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		out, err := X25519(scalar, Basepoint)
+		out, err := curve25519.X25519(scalar, curve25519.Basepoint)
 		if err != nil {
 			b.Fatal(err)
 		}
@@ -117,11 +119,11 @@
 }
 
 func BenchmarkX25519(b *testing.B) {
-	scalar := make([]byte, ScalarSize)
+	scalar := make([]byte, curve25519.ScalarSize)
 	if _, err := rand.Read(scalar); err != nil {
 		b.Fatal(err)
 	}
-	point, err := X25519(scalar, Basepoint)
+	point, err := curve25519.X25519(scalar, curve25519.Basepoint)
 	if err != nil {
 		b.Fatal(err)
 	}
@@ -131,7 +133,7 @@
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		out, err := X25519(scalar, point)
+		out, err := curve25519.X25519(scalar, point)
 		if err != nil {
 			b.Fatal(err)
 		}
diff --git a/curve25519/vectors_test.go b/curve25519/vectors_test.go
index 946e9a8..f4c0a14 100644
--- a/curve25519/vectors_test.go
+++ b/curve25519/vectors_test.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package curve25519
+package curve25519_test
 
 // lowOrderPoints from libsodium.
 // https://github.com/jedisct1/libsodium/blob/65621a1059a37d/src/libsodium/crypto_scalarmult/curve25519/ref10/x25519_ref10.c#L11-L70