x/crypto/ssh: implement curve25519-sha256@libssh.org key agreement.

Fixes golang/go#11004.

Change-Id: Ic37cf9d620e3397b7ad769ae16abdaee63a7733b
Reviewed-on: https://go-review.googlesource.com/13592
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/kex.go b/ssh/kex.go
index 6a835c7..ea19d53 100644
--- a/ssh/kex.go
+++ b/ssh/kex.go
@@ -8,18 +8,22 @@
 	"crypto"
 	"crypto/ecdsa"
 	"crypto/elliptic"
+	"crypto/subtle"
 	"crypto/rand"
 	"errors"
 	"io"
 	"math/big"
+
+	"golang.org/x/crypto/curve25519"
 )
 
 const (
-	kexAlgoDH1SHA1  = "diffie-hellman-group1-sha1"
-	kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
-	kexAlgoECDH256  = "ecdh-sha2-nistp256"
-	kexAlgoECDH384  = "ecdh-sha2-nistp384"
-	kexAlgoECDH521  = "ecdh-sha2-nistp521"
+	kexAlgoDH1SHA1          = "diffie-hellman-group1-sha1"
+	kexAlgoDH14SHA1         = "diffie-hellman-group14-sha1"
+	kexAlgoECDH256          = "ecdh-sha2-nistp256"
+	kexAlgoECDH384          = "ecdh-sha2-nistp384"
+	kexAlgoECDH521          = "ecdh-sha2-nistp521"
+	kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" // served by curve25519sha256
 )
 
 // kexResult captures the outcome of a key exchange.
@@ -383,4 +387,140 @@
 	kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
 	kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
 	kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
+	kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
+}
+
+// curve25519sha256 is the stateless kexAlgoMap entry for
+// kexAlgoCurve25519SHA256. The key agreement protocol it implements is described in
+// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
+type curve25519sha256 struct{}
+
+// curve25519KeyPair is a curve25519 private scalar with its public value.
+type curve25519KeyPair struct{ priv, pub [32]byte }
+
+// generate reads 32 bytes from rand into kp.priv and derives kp.pub from it.
+func (kp *curve25519KeyPair) generate(rand io.Reader) error {
+	_, err := io.ReadFull(rand, kp.priv[:])
+	if err != nil {
+		return err
+	}
+	curve25519.ScalarBaseMult(&kp.pub, &kp.priv)
+	return nil
+}
+
+// curve25519Zeros is 32 zero bytes. Client and Server compare the ScalarMult
+// output against it in constant time and abort the exchange on a match, which
+// rejects peer curve25519 public values with the wrong order.
+var curve25519Zeros [32]byte
+
+// Client runs the client half of the exchange over c: it sends a fresh
+// curve25519 public value, reads the server's reply and derives the exchange
+// hash H and shared secret K. It returns reply.Signature without verifying it.
+func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
+	var ephemeral curve25519KeyPair
+	if err := ephemeral.generate(rand); err != nil {
+		return nil, err
+	}
+	initMsg := kexECDHInitMsg{ClientPubKey: ephemeral.pub[:]}
+	if err := c.writePacket(Marshal(&initMsg)); err != nil {
+		return nil, err
+	}
+
+	packet, err := c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+	var reply kexECDHReplyMsg
+	if err := Unmarshal(packet, &reply); err != nil {
+		return nil, err
+	}
+	if len(reply.EphemeralPubKey) != 32 {
+		return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
+	}
+
+	var peerPub, shared [32]byte
+	copy(peerPub[:], reply.EphemeralPubKey)
+	curve25519.ScalarMult(&shared, &ephemeral.priv, &peerPub)
+	if subtle.ConstantTimeCompare(shared[:], curve25519Zeros[:]) == 1 {
+		return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
+	}
+	sharedInt := new(big.Int).SetBytes(shared[:])
+	K := make([]byte, intLength(sharedInt))
+	marshalInt(K, sharedInt)
+
+	// H covers, in order: magics, host key, client value, server value, K.
+	h := crypto.SHA256.New()
+	magics.write(h)
+	writeString(h, reply.HostKey)
+	writeString(h, ephemeral.pub[:])
+	writeString(h, reply.EphemeralPubKey)
+	h.Write(K)
+
+	return &kexResult{
+		H: h.Sum(nil), K: K, Hash: crypto.SHA256,
+		HostKey: reply.HostKey, Signature: reply.Signature,
+	}, nil
+}
+
+// Server runs the server half of the exchange over c: it reads the client's
+// curve25519 public value, derives the exchange hash H and shared secret K,
+// signs H with priv and sends the reply carrying the host key and signature.
+func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+	packet, err := c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+	var initMsg kexECDHInitMsg
+	if err := Unmarshal(packet, &initMsg); err != nil {
+		return nil, err
+	}
+	if len(initMsg.ClientPubKey) != 32 {
+		return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
+	}
+
+	var ephemeral curve25519KeyPair
+	if err := ephemeral.generate(rand); err != nil {
+		return nil, err
+	}
+
+	var peerPub, shared [32]byte
+	copy(peerPub[:], initMsg.ClientPubKey)
+	curve25519.ScalarMult(&shared, &ephemeral.priv, &peerPub)
+	if subtle.ConstantTimeCompare(shared[:], curve25519Zeros[:]) == 1 {
+		return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
+	}
+
+	hostKey := priv.PublicKey().Marshal()
+	sharedInt := new(big.Int).SetBytes(shared[:])
+	K := make([]byte, intLength(sharedInt))
+	marshalInt(K, sharedInt)
+
+	// H covers, in order: magics, host key, client value, server value, K.
+	h := crypto.SHA256.New()
+	magics.write(h)
+	writeString(h, hostKey)
+	writeString(h, initMsg.ClientPubKey)
+	writeString(h, ephemeral.pub[:])
+	h.Write(K)
+	H := h.Sum(nil)
+
+	// Sign H with priv; the signature travels in the reply.
+	sig, err := signAndMarshal(priv, rand, H)
+	if err != nil {
+		return nil, err
+	}
+
+	reply := kexECDHReplyMsg{
+		EphemeralPubKey: ephemeral.pub[:],
+		HostKey:         hostKey,
+		Signature:       sig,
+	}
+	if err := c.writePacket(Marshal(&reply)); err != nil {
+		return nil, err
+	}
+
+	return &kexResult{
+		H: H, K: K, Hash: crypto.SHA256,
+		HostKey: hostKey, Signature: sig,
+	}, nil
 }