ssh: add diffie-hellman-group-exchange-sha256

Add the diffie-hellman-group-exchange-sha256 defined in RFC 4419 to
the list of supported key exchange algorithms for ssh.
The server half is only a minimal implementation to satisfy the automated tests.

Fixes golang/go#17230

Change-Id: I25880a564347fd9b4738dd2ed1e347cd5d2e21bb
GitHub-Last-Rev: 9f0b8d02c0c96e9baf00cdf1cf063ff834245443
GitHub-Pull-Request: golang/crypto#87
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/174257
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index 9200cb3..63a8e22 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -316,7 +316,7 @@
 			PublicKeys(),
 		},
 		Config: Config{
-			KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
+			KeyExchanges: []string{"non-existent-kex"},
 		},
 		HostKeyCallback: InsecureIgnoreHostKey(),
 	}
diff --git a/ssh/common.go b/ssh/common.go
index d97415d..e55fe0a 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -51,6 +51,13 @@
 	kexAlgoDH14SHA1, kexAlgoDH1SHA1,
 }
 
+// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
+// for the server half.
+var serverForbiddenKexAlgos = map[string]struct{}{
+	kexAlgoDHGEXSHA1:   {}, // server half implementation is only minimal to satisfy the automated tests
+	kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
+}
+
 // supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
 // of authenticating servers) in preference order.
 var supportedHostKeyAlgos = []string{
diff --git a/ssh/kex.go b/ssh/kex.go
index f34bcc0..1607200 100644
--- a/ssh/kex.go
+++ b/ssh/kex.go
@@ -10,7 +10,9 @@
 	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/subtle"
+	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
 	"math/big"
 
@@ -24,6 +26,12 @@
 	kexAlgoECDH384          = "ecdh-sha2-nistp384"
 	kexAlgoECDH521          = "ecdh-sha2-nistp521"
 	kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org"
+
+	// For the following kex only the client half contains a production
+	// ready implementation. The server half only consists of a minimal
+	// implementation to satisfy the automated tests.
+	kexAlgoDHGEXSHA1   = "diffie-hellman-group-exchange-sha1"
+	kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
 )
 
 // kexResult captures the outcome of a key exchange.
@@ -402,6 +410,8 @@
 	kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
 	kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
 	kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
+	kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
+	kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
 }
 
 // curve25519sha256 implements the curve25519-sha256@libssh.org key
@@ -538,3 +548,242 @@
 		Hash:      crypto.SHA256,
 	}, nil
 }
+
+// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and
+// diffie-hellman-group-exchange-sha256 key agreement protocols,
+// as described in RFC 4419
+type dhGEXSHA struct {
+	g, p     *big.Int
+	hashFunc crypto.Hash
+}
+
+const numMRTests = 64
+
+const (
+	dhGroupExchangeMinimumBits   = 2048
+	dhGroupExchangePreferredBits = 2048
+	dhGroupExchangeMaximumBits   = 8192
+)
+
+func (gex *dhGEXSHA) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
+	if theirPublic.Sign() <= 0 || theirPublic.Cmp(gex.p) >= 0 {
+		return nil, fmt.Errorf("ssh: DH parameter out of bounds")
+	}
+	return new(big.Int).Exp(theirPublic, myPrivate, gex.p), nil
+}
+
+func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
+	// Send GexRequest
+	kexDHGexRequest := kexDHGexRequestMsg{
+		MinBits:      dhGroupExchangeMinimumBits,
+		PreferedBits: dhGroupExchangePreferredBits,
+		MaxBits:      dhGroupExchangeMaximumBits,
+	}
+	if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil {
+		return nil, err
+	}
+
+	// Receive GexGroup
+	packet, err := c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	var kexDHGexGroup kexDHGexGroupMsg
+	if err = Unmarshal(packet, &kexDHGexGroup); err != nil {
+		return nil, err
+	}
+
+	// reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits
+	if kexDHGexGroup.P.BitLen() < dhGroupExchangeMinimumBits || kexDHGexGroup.P.BitLen() > dhGroupExchangeMaximumBits {
+		return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", kexDHGexGroup.P.BitLen())
+	}
+
+	gex.p = kexDHGexGroup.P
+	gex.g = kexDHGexGroup.G
+
+	// Check if p is safe by verifing that p and (p-1)/2 are primes
+	one := big.NewInt(1)
+	var pHalf = &big.Int{}
+	pHalf.Rsh(gex.p, 1)
+	if !gex.p.ProbablyPrime(numMRTests) || !pHalf.ProbablyPrime(numMRTests) {
+		return nil, fmt.Errorf("ssh: server provided gex p is not safe")
+	}
+
+	// Check if g is safe by verifing that g > 1 and g < p - 1
+	var pMinusOne = &big.Int{}
+	pMinusOne.Sub(gex.p, one)
+	if gex.g.Cmp(one) != 1 && gex.g.Cmp(pMinusOne) != -1 {
+		return nil, fmt.Errorf("ssh: server provided gex g is not safe")
+	}
+
+	// Send GexInit
+	x, err := rand.Int(randSource, pHalf)
+	if err != nil {
+		return nil, err
+	}
+	X := new(big.Int).Exp(gex.g, x, gex.p)
+	kexDHGexInit := kexDHGexInitMsg{
+		X: X,
+	}
+	if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil {
+		return nil, err
+	}
+
+	// Receive GexReply
+	packet, err = c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	var kexDHGexReply kexDHGexReplyMsg
+	if err = Unmarshal(packet, &kexDHGexReply); err != nil {
+		return nil, err
+	}
+
+	kInt, err := gex.diffieHellman(kexDHGexReply.Y, x)
+	if err != nil {
+		return nil, err
+	}
+
+	// Check if k is safe by verifing that k > 1 and k < p - 1
+	if kInt.Cmp(one) != 1 && kInt.Cmp(pMinusOne) != -1 {
+		return nil, fmt.Errorf("ssh: derived k is not safe")
+	}
+
+	h := gex.hashFunc.New()
+	magics.write(h)
+	writeString(h, kexDHGexReply.HostKey)
+	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
+	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
+	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
+	writeInt(h, gex.p)
+	writeInt(h, gex.g)
+	writeInt(h, X)
+	writeInt(h, kexDHGexReply.Y)
+	K := make([]byte, intLength(kInt))
+	marshalInt(K, kInt)
+	h.Write(K)
+
+	return &kexResult{
+		H:         h.Sum(nil),
+		K:         K,
+		HostKey:   kexDHGexReply.HostKey,
+		Signature: kexDHGexReply.Signature,
+		Hash:      gex.hashFunc,
+	}, nil
+}
+
+// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
+//
+// This is a minimal implementation to satisfy the automated tests.
+func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+	// Receive GexRequest
+	packet, err := c.readPacket()
+	if err != nil {
+		return
+	}
+	var kexDHGexRequest kexDHGexRequestMsg
+	if err = Unmarshal(packet, &kexDHGexRequest); err != nil {
+		return
+	}
+
+	// smoosh the user's preferred size into our own limits
+	if kexDHGexRequest.PreferedBits > dhGroupExchangeMaximumBits {
+		kexDHGexRequest.PreferedBits = dhGroupExchangeMaximumBits
+	}
+	if kexDHGexRequest.PreferedBits < dhGroupExchangeMinimumBits {
+		kexDHGexRequest.PreferedBits = dhGroupExchangeMinimumBits
+	}
+	// fix min/max if they're inconsistent.  technically, we could just pout
+	// and hang up, but there's no harm in giving them the benefit of the
+	// doubt and just picking a bitsize for them.
+	if kexDHGexRequest.MinBits > kexDHGexRequest.PreferedBits {
+		kexDHGexRequest.MinBits = kexDHGexRequest.PreferedBits
+	}
+	if kexDHGexRequest.MaxBits < kexDHGexRequest.PreferedBits {
+		kexDHGexRequest.MaxBits = kexDHGexRequest.PreferedBits
+	}
+
+	// Send GexGroup
+	// This is the group called diffie-hellman-group14-sha1 in RFC
+	// 4253 and Oakley Group 14 in RFC 3526.
+	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+	gex.p = p
+	gex.g = big.NewInt(2)
+
+	kexDHGexGroup := kexDHGexGroupMsg{
+		P: gex.p,
+		G: gex.g,
+	}
+	if err := c.writePacket(Marshal(&kexDHGexGroup)); err != nil {
+		return nil, err
+	}
+
+	// Receive GexInit
+	packet, err = c.readPacket()
+	if err != nil {
+		return
+	}
+	var kexDHGexInit kexDHGexInitMsg
+	if err = Unmarshal(packet, &kexDHGexInit); err != nil {
+		return
+	}
+
+	var pHalf = &big.Int{}
+	pHalf.Rsh(gex.p, 1)
+
+	y, err := rand.Int(randSource, pHalf)
+	if err != nil {
+		return
+	}
+
+	Y := new(big.Int).Exp(gex.g, y, gex.p)
+	kInt, err := gex.diffieHellman(kexDHGexInit.X, y)
+	if err != nil {
+		return nil, err
+	}
+
+	hostKeyBytes := priv.PublicKey().Marshal()
+
+	h := gex.hashFunc.New()
+	magics.write(h)
+	writeString(h, hostKeyBytes)
+	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
+	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
+	binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
+	writeInt(h, gex.p)
+	writeInt(h, gex.g)
+	writeInt(h, kexDHGexInit.X)
+	writeInt(h, Y)
+
+	K := make([]byte, intLength(kInt))
+	marshalInt(K, kInt)
+	h.Write(K)
+
+	H := h.Sum(nil)
+
+	// H is already a hash, but the hostkey signing will apply its
+	// own key-specific hash algorithm.
+	sig, err := signAndMarshal(priv, randSource, H)
+	if err != nil {
+		return nil, err
+	}
+
+	kexDHGexReply := kexDHGexReplyMsg{
+		HostKey:   hostKeyBytes,
+		Y:         Y,
+		Signature: sig,
+	}
+	packet = Marshal(&kexDHGexReply)
+
+	err = c.writePacket(packet)
+
+	return &kexResult{
+		H:         H,
+		K:         K,
+		HostKey:   hostKeyBytes,
+		Signature: sig,
+		Hash:      gex.hashFunc,
+	}, err
+}
diff --git a/ssh/messages.go b/ssh/messages.go
index db914d8..ac41a41 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -97,6 +97,36 @@
 	Signature []byte
 }
 
+// See RFC 4419, section 5.
+const msgKexDHGexGroup = 31
+
+type kexDHGexGroupMsg struct {
+	P *big.Int `sshtype:"31"`
+	G *big.Int
+}
+
+const msgKexDHGexInit = 32
+
+type kexDHGexInitMsg struct {
+	X *big.Int `sshtype:"32"`
+}
+
+const msgKexDHGexReply = 33
+
+type kexDHGexReplyMsg struct {
+	HostKey   []byte `sshtype:"33"`
+	Y         *big.Int
+	Signature []byte
+}
+
+const msgKexDHGexRequest = 34
+
+type kexDHGexRequestMsg struct {
+	MinBits      uint32 `sshtype:"34"`
+	PreferedBits uint32
+	MaxBits      uint32
+}
+
 // See RFC 4253, section 10.
 const msgServiceRequest = 5
 
diff --git a/ssh/server.go b/ssh/server.go
index ac7f807..7a5a1d7 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -193,6 +193,12 @@
 	if fullConf.MaxAuthTries == 0 {
 		fullConf.MaxAuthTries = 6
 	}
+	// Check if the config contains any unsupported key exchanges
+	for _, kex := range fullConf.KeyExchanges {
+		if _, ok := serverForbiddenKexAlgos[kex]; ok {
+			return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
+		}
+	}
 
 	s := &connection{
 		sshConn: sshConn{conn: c},
diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go
index dcd2249..004b320 100644
--- a/ssh/test/session_test.go
+++ b/ssh/test/session_test.go
@@ -420,6 +420,11 @@
 	var config ssh.Config
 	config.SetDefaults()
 	kexOrder := config.KeyExchanges
+	// Based on the discussion in #17230, the key exchange algorithms
+	// diffie-hellman-group-exchange-sha1 and diffie-hellman-group-exchange-sha256
+	// are not included in the default list of supported kex so we have to add them
+	// here manually.
+	kexOrder = append(kexOrder, "diffie-hellman-group-exchange-sha1", "diffie-hellman-group-exchange-sha256")
 	for _, kex := range kexOrder {
 		t.Run(kex, func(t *testing.T) {
 			server := newServer(t)