go.crypto/ssh: separate kex algorithms into kexAlgorithm class.

Adds readPacket() to conn, and renames conn to packetConn.
Key exchanges operate on packetConn, so they can be
unittested.

R=agl, jpsugar, dave
CC=golang-dev
https://golang.org/cl/13352055
diff --git a/ssh/channel.go b/ssh/channel.go
index be7b19f..336f7f7 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -78,7 +78,7 @@
 )
 
 type channel struct {
-	conn              // the underlying transport
+	packetConn        // the underlying transport
 	localId, remoteId uint32
 	remoteWin         window
 	maxPacket         uint32
@@ -102,7 +102,7 @@
 
 // sendClose informs the remote side of our intent to close the channel.
 func (c *channel) sendClose() error {
-	return c.conn.writePacket(marshal(msgChannelClose, channelCloseMsg{
+	return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{
 		PeersId: c.remoteId,
 	}))
 }
@@ -124,7 +124,7 @@
 	if uint32(len(b)) > c.maxPacket {
 		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
 	}
-	return c.conn.writePacket(b)
+	return c.packetConn.writePacket(b)
 }
 
 func (c *channel) closed() bool {
@@ -447,12 +447,12 @@
 // newClientChan returns a partially constructed *clientChan
 // using the local id provided. To be usable clientChan.remoteId
 // needs to be assigned once known.
-func newClientChan(cc conn, id uint32) *clientChan {
+func newClientChan(cc packetConn, id uint32) *clientChan {
 	c := &clientChan{
 		channel: channel{
-			conn:      cc,
-			localId:   id,
-			remoteWin: window{Cond: newCond()},
+			packetConn: cc,
+			localId:    id,
+			remoteWin:  window{Cond: newCond()},
 		},
 		msg: make(chan interface{}, 16),
 	}
diff --git a/ssh/client.go b/ssh/client.go
index 6a5ec24..99e8377 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -5,15 +5,11 @@
 package ssh
 
 import (
-	"crypto"
-	"crypto/ecdsa"
-	"crypto/elliptic"
 	"crypto/rand"
 	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
-	"math/big"
 	"net"
 	"sync"
 )
@@ -63,17 +59,14 @@
 
 // handshake performs the client side key exchange. See RFC 4253 Section 7.
 func (c *ClientConn) handshake() error {
-	var magics handshakeMagics
-
-	var version []byte
+	var myVersion []byte
 	if len(c.config.ClientVersion) > 0 {
-		version = []byte(c.config.ClientVersion)
+		myVersion = []byte(c.config.ClientVersion)
 	} else {
-		version = clientVersion
+		myVersion = clientVersion
 	}
-	magics.clientVersion = version
-	version = append(version, '\r', '\n')
-	if _, err := c.Write(version); err != nil {
+
+	if _, err := c.Write(append(myVersion, '\r', '\n')); err != nil {
 		return err
 	}
 	if err := c.Flush(); err != nil {
@@ -81,12 +74,12 @@
 	}
 
 	// read remote server version
-	version, err := readVersion(c)
+	serverVersion, err := readVersion(c)
 	if err != nil {
 		return err
 	}
-	magics.serverVersion = version
-	c.serverVersion = string(version)
+	c.serverVersion = string(serverVersion)
+
 	clientKexInit := kexInitMsg{
 		KexAlgos:                c.config.Crypto.kexes(),
 		ServerHostKeyAlgos:      supportedHostKeyAlgos,
@@ -98,8 +91,6 @@
 		CompressionServerClient: supportedCompressions,
 	}
 	kexInitPacket := marshal(msgKexInit, clientKexInit)
-	magics.clientKexInit = kexInitPacket
-
 	if err := c.writePacket(kexInitPacket); err != nil {
 		return err
 	}
@@ -108,8 +99,6 @@
 		return err
 	}
 
-	magics.serverKexInit = packet
-
 	var serverKexInit kexInitMsg
 	if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
 		return err
@@ -128,23 +117,18 @@
 		}
 	}
 
-	var result *kexResult
-	switch kexAlgo {
-	case kexAlgoECDH256:
-		result, err = c.kexECDH(elliptic.P256(), &magics, hostKeyAlgo)
-	case kexAlgoECDH384:
-		result, err = c.kexECDH(elliptic.P384(), &magics, hostKeyAlgo)
-	case kexAlgoECDH521:
-		result, err = c.kexECDH(elliptic.P521(), &magics, hostKeyAlgo)
-	case kexAlgoDH14SHA1:
-		dhGroup14Once.Do(initDHGroup14)
-		result, err = c.kexDH(crypto.SHA1, dhGroup14, &magics, hostKeyAlgo)
-	case kexAlgoDH1SHA1:
-		dhGroup1Once.Do(initDHGroup1)
-		result, err = c.kexDH(crypto.SHA1, dhGroup1, &magics, hostKeyAlgo)
-	default:
-		err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+	kex, ok := kexAlgoMap[kexAlgo]
+	if !ok {
+		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
 	}
+
+	magics := handshakeMagics{
+		clientVersion: myVersion,
+		serverVersion: serverVersion,
+		clientKexInit: kexInitPacket,
+		serverKexInit: packet,
+	}
+	result, err := kex.Client(c, c.config.rand(), &magics)
 	if err != nil {
 		return err
 	}
@@ -164,7 +148,8 @@
 	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 	}
-	if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, result.Hash); err != nil {
+
+	if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
 		return err
 	}
 	if packet, err = c.readPacket(); err != nil {
@@ -173,72 +158,12 @@
 	if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 	}
-	if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, result.Hash); err != nil {
+	if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
 		return err
 	}
 	return c.authenticate(result.H)
 }
 
-// kexECDH performs Elliptic Curve Diffie-Hellman key exchange as
-// described in RFC 5656, section 4.
-func (c *ClientConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) {
-	ephKey, err := ecdsa.GenerateKey(curve, c.config.rand())
-	if err != nil {
-		return nil, err
-	}
-
-	kexInit := kexECDHInitMsg{
-		ClientPubKey: elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
-	}
-
-	serialized := marshal(msgKexECDHInit, kexInit)
-	if err := c.writePacket(serialized); err != nil {
-		return nil, err
-	}
-
-	packet, err := c.readPacket()
-	if err != nil {
-		return nil, err
-	}
-
-	var reply kexECDHReplyMsg
-	if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil {
-		return nil, err
-	}
-
-	x, y := elliptic.Unmarshal(curve, reply.EphemeralPubKey)
-	if x == nil {
-		return nil, errors.New("ssh: elliptic.Unmarshal failure")
-	}
-	if !validateECPublicKey(curve, x, y) {
-		return nil, errors.New("ssh: ephemeral server key not on curve")
-	}
-
-	// generate shared secret
-	secret, _ := curve.ScalarMult(x, y, ephKey.D.Bytes())
-
-	hashFunc := ecHash(curve)
-	h := hashFunc.New()
-	writeString(h, magics.clientVersion)
-	writeString(h, magics.serverVersion)
-	writeString(h, magics.clientKexInit)
-	writeString(h, magics.serverKexInit)
-	writeString(h, reply.HostKey)
-	writeString(h, kexInit.ClientPubKey)
-	writeString(h, reply.EphemeralPubKey)
-	K := make([]byte, intLength(secret))
-	marshalInt(K, secret)
-	h.Write(K)
-
-	return &kexResult{
-		H:         h.Sum(nil),
-		K:         K,
-		HostKey:   reply.HostKey,
-		Signature: reply.Signature,
-		Hash:      hashFunc,
-	}, nil
-}
-
 // Verify the host key obtained in the key exchange.
 func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error {
 	hostKey, rest, ok := ParsePublicKey(hostKeyBytes)
@@ -260,74 +185,6 @@
 	return nil
 }
 
-// kexResult captures the outcome of a key exchange.
-type kexResult struct {
-	// Session hash. See also RFC 4253, section 8.
-	H []byte
-
-	// Shared secret. See also RFC 4253, section 8.
-	K []byte
-
-	// Host key as hashed into H
-	HostKey []byte
-
-	// Signature of H
-	Signature []byte
-
-	// Hash function that was used.
-	Hash crypto.Hash
-}
-
-// kexDH performs Diffie-Hellman key agreement on a ClientConn.
-func (c *ClientConn) kexDH(hashFunc crypto.Hash, group *dhGroup, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) {
-	x, err := rand.Int(c.config.rand(), group.p)
-	if err != nil {
-		return nil, err
-	}
-	X := new(big.Int).Exp(group.g, x, group.p)
-	kexDHInit := kexDHInitMsg{
-		X: X,
-	}
-	if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
-		return nil, err
-	}
-
-	packet, err := c.readPacket()
-	if err != nil {
-		return nil, err
-	}
-
-	var kexDHReply kexDHReplyMsg
-	if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
-		return nil, err
-	}
-
-	kInt, err := group.diffieHellman(kexDHReply.Y, x)
-	if err != nil {
-		return nil, err
-	}
-
-	h := hashFunc.New()
-	writeString(h, magics.clientVersion)
-	writeString(h, magics.serverVersion)
-	writeString(h, magics.clientKexInit)
-	writeString(h, magics.serverKexInit)
-	writeString(h, kexDHReply.HostKey)
-	writeInt(h, X)
-	writeInt(h, kexDHReply.Y)
-	K := make([]byte, intLength(kInt))
-	marshalInt(K, kInt)
-	h.Write(K)
-
-	return &kexResult{
-		H:         h.Sum(nil),
-		K:         K,
-		HostKey:   kexDHReply.HostKey,
-		Signature: kexDHReply.Signature,
-		Hash:      hashFunc,
-	}, nil
-}
-
 // mainLoop reads incoming messages and routes channel messages
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
@@ -633,18 +490,18 @@
 }
 
 // Allocate a new ClientChan with the next avail local id.
-func (c *chanList) newChan(t *transport) *clientChan {
+func (c *chanList) newChan(p packetConn) *clientChan {
 	c.Lock()
 	defer c.Unlock()
 	for i := range c.chans {
 		if c.chans[i] == nil {
-			ch := newClientChan(t, uint32(i))
+			ch := newClientChan(p, uint32(i))
 			c.chans[i] = ch
 			return ch
 		}
 	}
 	i := len(c.chans)
-	ch := newClientChan(t, uint32(i))
+	ch := newClientChan(p, uint32(i))
 	c.chans = append(c.chans, ch)
 	return ch
 }
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 31a5e09..d2ed48e 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -81,7 +81,7 @@
 	// Returns true if authentication is successful.
 	// If authentication is not successful, a []string of alternative
 	// method names is returned.
-	auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
+	auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
 
 	// method returns the RFC 4252 method name.
 	method() string
@@ -90,8 +90,8 @@
 // "none" authentication, RFC 4252 section 5.2.
 type noneAuth int
 
-func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
-	if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
+func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+	if err := c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
 		User:    user,
 		Service: serviceSSH,
 		Method:  "none",
@@ -99,7 +99,7 @@
 		return false, nil, err
 	}
 
-	return handleAuthResponse(t)
+	return handleAuthResponse(c)
 }
 
 func (n *noneAuth) method() string {
@@ -111,7 +111,7 @@
 	ClientPassword
 }
 
-func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
+func (p *passwordAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
 	type passwordAuthMsg struct {
 		User     string
 		Service  string
@@ -125,7 +125,7 @@
 		return false, nil, err
 	}
 
-	if err := t.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{
+	if err := c.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{
 		User:     user,
 		Service:  serviceSSH,
 		Method:   "password",
@@ -135,7 +135,7 @@
 		return false, nil, err
 	}
 
-	return handleAuthResponse(t)
+	return handleAuthResponse(c)
 }
 
 func (p *passwordAuth) method() string {
@@ -181,7 +181,7 @@
 	Sig []byte `ssh:"rest"`
 }
 
-func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
+func (p *publickeyAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
 	// Authentication is performed in two stages. The first stage sends an
 	// enquiry to test if each key is acceptable to the remote. The second
 	// stage attempts to authenticate with the valid keys obtained in the
@@ -200,7 +200,7 @@
 			break
 		}
 
-		if ok, err := p.validateKey(key, user, t); ok {
+		if ok, err := p.validateKey(key, user, c); ok {
 			validKeys[index] = key
 		} else {
 			if err != nil {
@@ -237,10 +237,10 @@
 			Sig:      sig,
 		}
 		p := marshal(msgUserAuthRequest, msg)
-		if err := t.writePacket(p); err != nil {
+		if err := c.writePacket(p); err != nil {
 			return false, nil, err
 		}
-		success, methods, err := handleAuthResponse(t)
+		success, methods, err := handleAuthResponse(c)
 		if err != nil {
 			return false, nil, err
 		}
@@ -252,7 +252,7 @@
 }
 
 // validateKey validates the key provided it is acceptable to the server.
-func (p *publickeyAuth) validateKey(key PublicKey, user string, t *transport) (bool, error) {
+func (p *publickeyAuth) validateKey(key PublicKey, user string, c packetConn) (bool, error) {
 	pubkey := MarshalPublicKey(key)
 	algoname := key.PublicKeyAlgo()
 	msg := publickeyAuthMsg{
@@ -263,19 +263,19 @@
 		Algoname: algoname,
 		Pubkey:   string(pubkey),
 	}
-	if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
+	if err := c.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
 		return false, err
 	}
 
-	return p.confirmKeyAck(key, t)
+	return p.confirmKeyAck(key, c)
 }
 
-func (p *publickeyAuth) confirmKeyAck(key PublicKey, t *transport) (bool, error) {
+func (p *publickeyAuth) confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
 	pubkey := MarshalPublicKey(key)
 	algoname := key.PublicKeyAlgo()
 
 	for {
-		packet, err := t.readPacket()
+		packet, err := c.readPacket()
 		if err != nil {
 			return false, err
 		}
@@ -312,9 +312,9 @@
 // handleAuthResponse returns whether the preceding authentication request succeeded
 // along with a list of remaining authentication methods to try next and
 // an error if an unexpected response was received.
-func handleAuthResponse(t *transport) (bool, []string, error) {
+func handleAuthResponse(c packetConn) (bool, []string, error) {
 	for {
-		packet, err := t.readPacket()
+		packet, err := c.readPacket()
 		if err != nil {
 			return false, nil, err
 		}
@@ -411,11 +411,11 @@
 	ClientKeyboardInteractive
 }
 
-func (c *keyboardInteractiveAuth) method() string {
+func (k *keyboardInteractiveAuth) method() string {
 	return "keyboard-interactive"
 }
 
-func (c *keyboardInteractiveAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
+func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
 	type initiateMsg struct {
 		User       string
 		Service    string
@@ -424,7 +424,7 @@
 		Submethods string
 	}
 
-	if err := t.writePacket(marshal(msgUserAuthRequest, initiateMsg{
+	if err := c.writePacket(marshal(msgUserAuthRequest, initiateMsg{
 		User:    user,
 		Service: serviceSSH,
 		Method:  "keyboard-interactive",
@@ -433,7 +433,7 @@
 	}
 
 	for {
-		packet, err := t.readPacket()
+		packet, err := c.readPacket()
 		if err != nil {
 			return false, nil, err
 		}
@@ -480,7 +480,7 @@
 			return false, nil, fmt.Errorf("ssh: junk following message %q", rest)
 		}
 
-		answers, err := c.Challenge(msg.User, msg.Instruction, prompts, echos)
+		answers, err := k.Challenge(msg.User, msg.Instruction, prompts, echos)
 		if err != nil {
 			return false, nil, err
 		}
@@ -501,7 +501,7 @@
 			p = marshalString(p, []byte(a))
 		}
 
-		if err := t.writePacket(serialized); err != nil {
+		if err := c.writePacket(serialized); err != nil {
 			return false, nil, err
 		}
 	}
diff --git a/ssh/common.go b/ssh/common.go
index 7e6e5dc..e640000 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -6,9 +6,7 @@
 
 import (
 	"crypto"
-	"errors"
 	"fmt"
-	"math/big"
 	"sync"
 
 	_ "crypto/sha1"
@@ -18,11 +16,6 @@
 
 // These are string constants in the SSH protocol.
 const (
-	kexAlgoDH1SHA1  = "diffie-hellman-group1-sha1"
-	kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
-	kexAlgoECDH256  = "ecdh-sha2-nistp256"
-	kexAlgoECDH384  = "ecdh-sha2-nistp384"
-	kexAlgoECDH521  = "ecdh-sha2-nistp521"
 	hostAlgoRSA     = "ssh-rsa"
 	hostAlgoDSA     = "ssh-dss"
 	compressionNone = "none"
@@ -53,48 +46,6 @@
 	CertAlgoECDSA521v01: crypto.SHA512,
 }
 
-// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
-type dhGroup struct {
-	g, p *big.Int
-}
-
-func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
-	if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
-		return nil, errors.New("ssh: DH parameter out of bounds")
-	}
-	return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
-}
-
-// dhGroup1 is the group called diffie-hellman-group1-sha1 in RFC 4253 and
-// Oakley Group 2 in RFC 2409.
-var dhGroup1 *dhGroup
-
-var dhGroup1Once sync.Once
-
-func initDHGroup1() {
-	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
-
-	dhGroup1 = &dhGroup{
-		g: new(big.Int).SetInt64(2),
-		p: p,
-	}
-}
-
-// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
-// Oakley Group 14 in RFC 3526.
-var dhGroup14 *dhGroup
-
-var dhGroup14Once sync.Once
-
-func initDHGroup14() {
-	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
-
-	dhGroup14 = &dhGroup{
-		g: new(big.Int).SetInt64(2),
-		p: p,
-	}
-}
-
 // UnexpectedMessageError results when the SSH message that we received didn't
 // match what we wanted.
 type UnexpectedMessageError struct {
@@ -114,11 +65,6 @@
 	return fmt.Sprintf("ssh: parse error in message type %d", p.msgType)
 }
 
-type handshakeMagics struct {
-	clientVersion, serverVersion []byte
-	clientKexInit, serverKexInit []byte
-}
-
 func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
 	for _, clientAlgo := range clientAlgos {
 		for _, serverAlgo := range serverAlgos {
diff --git a/ssh/kex.go b/ssh/kex.go
new file mode 100644
index 0000000..e6ce6c3
--- /dev/null
+++ b/ssh/kex.go
@@ -0,0 +1,387 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"crypto"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"errors"
+	"io"
+	"math/big"
+)
+
+const (
+	kexAlgoDH1SHA1  = "diffie-hellman-group1-sha1"
+	kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
+	kexAlgoECDH256  = "ecdh-sha2-nistp256"
+	kexAlgoECDH384  = "ecdh-sha2-nistp384"
+	kexAlgoECDH521  = "ecdh-sha2-nistp521"
+)
+
+// kexResult captures the outcome of a key exchange.
+type kexResult struct {
+	// Session hash. See also RFC 4253, section 8.
+	H []byte
+
+	// Shared secret. See also RFC 4253, section 8.
+	K []byte
+
+	// Host key as hashed into H
+	HostKey []byte
+
+	// Signature of H
+	Signature []byte
+}
+
+// handshakeMagics contains data that is always included in the
+// session hash.
+type handshakeMagics struct {
+	clientVersion, serverVersion []byte
+	clientKexInit, serverKexInit []byte
+}
+
+func (m *handshakeMagics) write(w io.Writer) {
+	writeString(w, m.clientVersion)
+	writeString(w, m.serverVersion)
+	writeString(w, m.clientKexInit)
+	writeString(w, m.serverKexInit)
+}
+
+// kexAlgorithm abstracts different key exchange algorithms.
+type kexAlgorithm interface {
+	// Server runs server-side key agreement, signing the result
+	// with a hostkey.
+	Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error)
+
+	// Client runs the client-side key agreement. Caller is
+	// responsible for verifying the host key signature.
+	Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
+
+	// Hash returns a cryptographic hash function that matches the
+	// security level of the key exchange algorithm. It is used
+	// for calculating kexResult.H, and for deriving keys from
+	// data in kexResult.
+	Hash() crypto.Hash
+}
+
+// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
+type dhGroup struct {
+	g, p *big.Int
+}
+
+func (group *dhGroup) Hash() crypto.Hash {
+	return crypto.SHA1
+}
+
+func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
+	if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
+		return nil, errors.New("ssh: DH parameter out of bounds")
+	}
+	return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
+}
+
+func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
+	hashFunc := crypto.SHA1
+
+	x, err := rand.Int(randSource, group.p)
+	if err != nil {
+		return nil, err
+	}
+	X := new(big.Int).Exp(group.g, x, group.p)
+	kexDHInit := kexDHInitMsg{
+		X: X,
+	}
+	if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
+		return nil, err
+	}
+
+	packet, err := c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	var kexDHReply kexDHReplyMsg
+	if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
+		return nil, err
+	}
+
+	kInt, err := group.diffieHellman(kexDHReply.Y, x)
+	if err != nil {
+		return nil, err
+	}
+
+	h := hashFunc.New()
+	magics.write(h)
+	writeString(h, kexDHReply.HostKey)
+	writeInt(h, X)
+	writeInt(h, kexDHReply.Y)
+	K := make([]byte, intLength(kInt))
+	marshalInt(K, kInt)
+	h.Write(K)
+
+	return &kexResult{
+		H:         h.Sum(nil),
+		K:         K,
+		HostKey:   kexDHReply.HostKey,
+		Signature: kexDHReply.Signature,
+	}, nil
+}
+
+func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+	hashFunc := crypto.SHA1
+	packet, err := c.readPacket()
+	if err != nil {
+		return
+	}
+	var kexDHInit kexDHInitMsg
+	if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
+		return
+	}
+
+	y, err := rand.Int(randSource, group.p)
+	if err != nil {
+		return
+	}
+
+	Y := new(big.Int).Exp(group.g, y, group.p)
+	kInt, err := group.diffieHellman(kexDHInit.X, y)
+	if err != nil {
+		return nil, err
+	}
+
+	hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+
+	h := hashFunc.New()
+	magics.write(h)
+	writeString(h, hostKeyBytes)
+	writeInt(h, kexDHInit.X)
+	writeInt(h, Y)
+
+	K := make([]byte, intLength(kInt))
+	marshalInt(K, kInt)
+	h.Write(K)
+
+	H := h.Sum(nil)
+
+	// H is already a hash, but the hostkey signing will apply its
+	// own key-specific hash algorithm.
+	sig, err := signAndMarshal(priv, randSource, H)
+	if err != nil {
+		return nil, err
+	}
+
+	kexDHReply := kexDHReplyMsg{
+		HostKey:   hostKeyBytes,
+		Y:         Y,
+		Signature: sig,
+	}
+	packet = marshal(msgKexDHReply, kexDHReply)
+
+	err = c.writePacket(packet)
+	return &kexResult{
+		H:         H,
+		K:         K,
+		HostKey:   hostKeyBytes,
+		Signature: sig,
+	}, nil
+}
+
+// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
+// described in RFC 5656, section 4.
+type ecdh struct {
+	curve elliptic.Curve
+}
+
+func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
+	ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
+	if err != nil {
+		return nil, err
+	}
+
+	kexInit := kexECDHInitMsg{
+		ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
+	}
+
+	serialized := marshal(msgKexECDHInit, kexInit)
+	if err := c.writePacket(serialized); err != nil {
+		return nil, err
+	}
+
+	packet, err := c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	var reply kexECDHReplyMsg
+	if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil {
+		return nil, err
+	}
+
+	x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey)
+	if err != nil {
+		return nil, err
+	}
+
+	// generate shared secret
+	secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes())
+
+	h := ecHash(kex.curve).New()
+	magics.write(h)
+	writeString(h, reply.HostKey)
+	writeString(h, kexInit.ClientPubKey)
+	writeString(h, reply.EphemeralPubKey)
+	K := make([]byte, intLength(secret))
+	marshalInt(K, secret)
+	h.Write(K)
+
+	return &kexResult{
+		H:         h.Sum(nil),
+		K:         K,
+		HostKey:   reply.HostKey,
+		Signature: reply.Signature,
+	}, nil
+}
+
+// unmarshalECKey parses and checks an EC key.
+func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) {
+	x, y = elliptic.Unmarshal(curve, pubkey)
+	if x == nil {
+		return nil, nil, errors.New("ssh: elliptic.Unmarshal failure")
+	}
+	if !validateECPublicKey(curve, x, y) {
+		return nil, nil, errors.New("ssh: public key not on curve")
+	}
+	return x, y, nil
+}
+
+// validateECPublicKey checks that the point is a valid public key for
+// the given curve. See [SEC1], 3.2.2
+func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
+	if x.Sign() == 0 && y.Sign() == 0 {
+		return false
+	}
+
+	if x.Cmp(curve.Params().P) >= 0 {
+		return false
+	}
+
+	if y.Cmp(curve.Params().P) >= 0 {
+		return false
+	}
+
+	if !curve.IsOnCurve(x, y) {
+		return false
+	}
+
+	// We don't check if N * PubKey == 0, since
+	//
+	// - the NIST curves have cofactor = 1, so this is implicit.
+	// (We don't forsee an implementation that supports non NIST
+	// curves)
+	//
+	// - for ephemeral keys, we don't need to worry about small
+	// subgroup attacks.
+	return true
+}
+
+func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+	packet, err := c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	var kexECDHInit kexECDHInitMsg
+	if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil {
+		return nil, err
+	}
+
+	clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey)
+	if err != nil {
+		return nil, err
+	}
+
+	// We could cache this key across multiple users/multiple
+	// connection attempts, but the benefit is small. OpenSSH
+	// generates a new key for each incoming connection.
+	ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
+	if err != nil {
+		return nil, err
+	}
+
+	hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+
+	serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
+
+	// generate shared secret
+	secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
+
+	h := ecHash(kex.curve).New()
+	magics.write(h)
+	writeString(h, hostKeyBytes)
+	writeString(h, kexECDHInit.ClientPubKey)
+	writeString(h, serializedEphKey)
+
+	K := make([]byte, intLength(secret))
+	marshalInt(K, secret)
+	h.Write(K)
+
+	H := h.Sum(nil)
+
+	// H is already a hash, but the hostkey signing will apply its
+	// own key-specific hash algorithm.
+	sig, err := signAndMarshal(priv, rand, H)
+	if err != nil {
+		return nil, err
+	}
+
+	reply := kexECDHReplyMsg{
+		EphemeralPubKey: serializedEphKey,
+		HostKey:         hostKeyBytes,
+		Signature:       sig,
+	}
+
+	serialized := marshal(msgKexECDHReply, reply)
+	if err := c.writePacket(serialized); err != nil {
+		return nil, err
+	}
+
+	return &kexResult{
+		H:         H,
+		K:         K,
+		HostKey:   reply.HostKey,
+		Signature: sig,
+	}, nil
+}
+
+func (kex *ecdh) Hash() crypto.Hash {
+	return ecHash(kex.curve)
+}
+
+var kexAlgoMap = map[string]kexAlgorithm{}
+
+func init() {
+	// This is the group called diffie-hellman-group1-sha1 in RFC
+	// 4253 and Oakley Group 2 in RFC 2409.
+	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
+	kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
+		g: new(big.Int).SetInt64(2),
+		p: p,
+	}
+
+	// This is the group called diffie-hellman-group14-sha1 in RFC
+	// 4253 and Oakley Group 14 in RFC 3526.
+	p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+
+	kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
+		g: new(big.Int).SetInt64(2),
+		p: p,
+	}
+
+	kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
+	kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
+	kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
+}
diff --git a/ssh/kex_test.go b/ssh/kex_test.go
index 104009e..74e80c4 100644
--- a/ssh/kex_test.go
+++ b/ssh/kex_test.go
@@ -7,84 +7,73 @@
 // Key exchange tests.
 
 import (
-	"fmt"
-	"net"
+	"crypto/rand"
+	"io"
+	"reflect"
 	"testing"
 )
 
-func pipe() (net.Conn, net.Conn, error) {
-	l, err := net.Listen("tcp", "127.0.0.1:0")
-	if err != nil {
-		return nil, nil, err
-	}
-	conn1, err := net.Dial("tcp", l.Addr().String())
-	if err != nil {
-		return nil, nil, err
-	}
-
-	conn2, err := l.Accept()
-	if err != nil {
-		conn1.Close()
-		return nil, nil, err
-	}
-	l.Close()
-	return conn1, conn2, nil
+// An in-memory packetConn.
+type memTransport struct {
+	r, w chan []byte
 }
 
-func testKexAlgorithm(algo string) error {
-	crypto := CryptoConfig{
-		KeyExchanges: []string{algo},
-	}
-	serverConfig := ServerConfig{
-		PasswordCallback: func(conn *ServerConn, user, password string) bool {
-			return password == "password"
-		},
-		Crypto: crypto,
+func (t *memTransport) readPacket() ([]byte, error) {
+	p, ok := <-t.r
+	if !ok {
+		return nil, io.EOF
 	}
 
-	if err := serverConfig.SetRSAPrivateKey([]byte(testServerPrivateKey)); err != nil {
-		return fmt.Errorf("SetRSAPrivateKey: %v", err)
-	}
+	return p, nil
+}
 
-	clientConfig := ClientConfig{
-		User:   "user",
-		Auth:   []ClientAuth{ClientAuthPassword(password("password"))},
-		Crypto: crypto,
-	}
-
-	conn1, conn2, err := pipe()
-	if err != nil {
-		return err
-	}
-
-	defer conn1.Close()
-	defer conn2.Close()
-
-	server := Server(conn2, &serverConfig)
-	serverHS := make(chan error, 1)
-	go func() {
-		serverHS <- server.Handshake()
-	}()
-
-	// Client runs the handshake.
-	_, err = Client(conn1, &clientConfig)
-	if err != nil {
-		return fmt.Errorf("Client: %v", err)
-	}
-
-	if err := <-serverHS; err != nil {
-		return fmt.Errorf("server.Handshake: %v", err)
-	}
-
-	// Here we could check that we now can send data between client &
-	// server.
+func (t *memTransport) Close() error {
+	close(t.w)
 	return nil
 }
 
-func TestKexAlgorithms(t *testing.T) {
-	for _, algo := range []string{kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, kexAlgoDH1SHA1, kexAlgoDH14SHA1} {
-		if err := testKexAlgorithm(algo); err != nil {
-			t.Errorf("algorithm %s: %v", algo, err)
+func (t *memTransport) writePacket(p []byte) error {
+	t.w <- p
+	return nil
+}
+
+func memPipe() (a, b packetConn) {
+	p := make(chan []byte, 1)
+	q := make(chan []byte, 1)
+	return &memTransport{p, q}, &memTransport{q, p}
+}
+
+func TestKexes(t *testing.T) {
+	type kexResultErr struct {
+		result *kexResult
+		err    error
+	}
+
+	for name, kex := range kexAlgoMap {
+		a, b := memPipe()
+
+		s := make(chan kexResultErr, 1)
+		c := make(chan kexResultErr, 1)
+		var magics handshakeMagics
+		go func() {
+			r, e := kex.Client(a, rand.Reader, &magics)
+			c <- kexResultErr{r, e}
+		}()
+		go func() {
+			r, e := kex.Server(b, rand.Reader, &magics, ecdsaKey)
+			s <- kexResultErr{r, e}
+		}()
+
+		clientRes := <-c
+		serverRes := <-s
+		if clientRes.err != nil {
+			t.Errorf("client: %v", clientRes.err)
+		}
+		if serverRes.err != nil {
+			t.Errorf("server: %v", serverRes.err)
+		}
+		if !reflect.DeepEqual(clientRes.result, serverRes.result) {
+			t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
 		}
 	}
 }
diff --git a/ssh/server.go b/ssh/server.go
index ffc35dd..cb6fe8c 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -6,14 +6,11 @@
 
 import (
 	"bytes"
-	"crypto"
-	"crypto/ecdsa"
-	"crypto/elliptic"
 	"crypto/rand"
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
-	"math/big"
 	"net"
 	"sync"
 
@@ -140,177 +137,6 @@
 	}
 }
 
-// kexECDH performs Elliptic Curve Diffie-Hellman key agreement on a
-// ServerConnection, as documented in RFC 5656, section 4.
-func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
-	packet, err := s.readPacket()
-	if err != nil {
-		return
-	}
-
-	var kexECDHInit kexECDHInitMsg
-	if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil {
-		return
-	}
-
-	clientX, clientY := elliptic.Unmarshal(curve, kexECDHInit.ClientPubKey)
-	if clientX == nil {
-		return nil, errors.New("ssh: elliptic.Unmarshal failure")
-	}
-
-	if !validateECPublicKey(curve, clientX, clientY) {
-		return nil, errors.New("ssh: not a valid EC public key")
-	}
-
-	// We could cache this key across multiple users/multiple
-	// connection attempts, but the benefit is small. OpenSSH
-	// generates a new key for each incoming connection.
-	ephKey, err := ecdsa.GenerateKey(curve, s.config.rand())
-	if err != nil {
-		return nil, err
-	}
-
-	hostKeyBytes := MarshalPublicKey(priv.PublicKey())
-
-	serializedEphKey := elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
-
-	// generate shared secret
-	secret, _ := curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
-
-	hashFunc := ecHash(curve)
-	h := hashFunc.New()
-	writeString(h, magics.clientVersion)
-	writeString(h, magics.serverVersion)
-	writeString(h, magics.clientKexInit)
-	writeString(h, magics.serverKexInit)
-	writeString(h, hostKeyBytes)
-	writeString(h, kexECDHInit.ClientPubKey)
-	writeString(h, serializedEphKey)
-
-	K := make([]byte, intLength(secret))
-	marshalInt(K, secret)
-	h.Write(K)
-
-	H := h.Sum(nil)
-
-	// H is already a hash, but the hostkey signing will apply its
-	// own key specific hash algorithm.
-	sig, err := signAndMarshal(priv, s.config.rand(), H)
-	if err != nil {
-		return nil, err
-	}
-
-	reply := kexECDHReplyMsg{
-		EphemeralPubKey: serializedEphKey,
-		HostKey:         hostKeyBytes,
-		Signature:       sig,
-	}
-
-	serialized := marshal(msgKexECDHReply, reply)
-	if err := s.writePacket(serialized); err != nil {
-		return nil, err
-	}
-
-	return &kexResult{
-		H:       H,
-		K:       K,
-		HostKey: reply.HostKey,
-		Hash:    hashFunc,
-	}, nil
-}
-
-// validateECPublicKey checks that the point is a valid public key for
-// the given curve. See [SEC1], 3.2.2
-func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
-	if x.Sign() == 0 && y.Sign() == 0 {
-		return false
-	}
-
-	if x.Cmp(curve.Params().P) >= 0 {
-		return false
-	}
-
-	if y.Cmp(curve.Params().P) >= 0 {
-		return false
-	}
-
-	if !curve.IsOnCurve(x, y) {
-		return false
-	}
-
-	// We don't check if N * PubKey == 0, since
-	//
-	// - the NIST curves have cofactor = 1, so this is implicit.
-	// (We don't forsee an implementation that supports non NIST
-	// curves)
-	//
-	// - for ephemeral keys, we don't need to worry about small
-	// subgroup attacks.
-	return true
-}
-
-// kexDH performs Diffie-Hellman key agreement on a ServerConnection.
-func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
-	packet, err := s.readPacket()
-	if err != nil {
-		return
-	}
-	var kexDHInit kexDHInitMsg
-	if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
-		return
-	}
-
-	y, err := rand.Int(s.config.rand(), group.p)
-	if err != nil {
-		return
-	}
-
-	Y := new(big.Int).Exp(group.g, y, group.p)
-	kInt, err := group.diffieHellman(kexDHInit.X, y)
-	if err != nil {
-		return nil, err
-	}
-
-	hostKeyBytes := MarshalPublicKey(priv.PublicKey())
-
-	h := hashFunc.New()
-	writeString(h, magics.clientVersion)
-	writeString(h, magics.serverVersion)
-	writeString(h, magics.clientKexInit)
-	writeString(h, magics.serverKexInit)
-	writeString(h, hostKeyBytes)
-	writeInt(h, kexDHInit.X)
-	writeInt(h, Y)
-
-	K := make([]byte, intLength(kInt))
-	marshalInt(K, kInt)
-	h.Write(K)
-
-	H := h.Sum(nil)
-
-	// H is already a hash, but the hostkey signing will apply its
-	// own key specific hash algorithm.
-	sig, err := signAndMarshal(priv, s.config.rand(), H)
-	if err != nil {
-		return nil, err
-	}
-
-	kexDHReply := kexDHReplyMsg{
-		HostKey:   hostKeyBytes,
-		Y:         Y,
-		Signature: sig,
-	}
-	packet = marshal(msgKexDHReply, kexDHReply)
-
-	err = s.writePacket(packet)
-	return &kexResult{
-		H:       H,
-		K:       K,
-		HostKey: hostKeyBytes,
-		Hash:    hashFunc,
-	}, nil
-}
-
 // signAndMarshal signs the data with the appropriate algorithm,
 // and serializes the result in SSH wire format.
 func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
@@ -330,8 +156,8 @@
 	if _, err = s.Write(serverVersion); err != nil {
 		return
 	}
-	if err = s.Flush(); err != nil {
-		return
+	if err := s.Flush(); err != nil {
+		return err
 	}
 
 	s.ClientVersion, err = readVersion(s)
@@ -415,32 +241,22 @@
 		}
 	}
 
-	var magics handshakeMagics
-	magics.serverVersion = serverVersion[:len(serverVersion)-2]
-	magics.clientVersion = s.ClientVersion
-	magics.serverKexInit = marshal(msgKexInit, serverKexInit)
-	magics.clientKexInit = clientKexInitPacket
+	kex, ok := kexAlgoMap[kexAlgo]
+	if !ok {
+		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+	}
 
-	var result *kexResult
-	switch kexAlgo {
-	case kexAlgoECDH256:
-		result, err = s.kexECDH(elliptic.P256(), &magics, hostKey)
-	case kexAlgoECDH384:
-		result, err = s.kexECDH(elliptic.P384(), &magics, hostKey)
-	case kexAlgoECDH521:
-		result, err = s.kexECDH(elliptic.P521(), &magics, hostKey)
-	case kexAlgoDH14SHA1:
-		dhGroup14Once.Do(initDHGroup14)
-		result, err = s.kexDH(dhGroup14, crypto.SHA1, &magics, hostKey)
-	case kexAlgoDH1SHA1:
-		dhGroup1Once.Do(initDHGroup1)
-		result, err = s.kexDH(dhGroup1, crypto.SHA1, &magics, hostKey)
-	default:
-		err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
+	magics := handshakeMagics{
+		serverVersion: serverVersion[:len(serverVersion)-2],
+		clientVersion: s.ClientVersion,
+		serverKexInit: marshal(msgKexInit, serverKexInit),
+		clientKexInit: clientKexInitPacket,
 	}
+	result, err := kex.Server(s, s.config.rand(), &magics, hostKey)
 	if err != nil {
-		return
+		return err
 	}
+
 	// sessionId must only be assigned during initial handshake.
 	if s.sessionId == nil {
 		s.sessionId = result.H
@@ -451,7 +267,7 @@
 	if err = s.writePacket([]byte{msgNewKeys}); err != nil {
 		return
 	}
-	if err = s.transport.writer.setupKeys(serverKeys, result.K, result.H, s.sessionId, result.Hash); err != nil {
+	if err = s.transport.writer.setupKeys(serverKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
 		return
 	}
 
@@ -461,7 +277,7 @@
 	if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 	}
-	if err = s.transport.reader.setupKeys(clientKeys, result.K, result.H, s.sessionId, result.Hash); err != nil {
+	if err = s.transport.reader.setupKeys(clientKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
 		return
 	}
 
@@ -760,10 +576,10 @@
 				}
 				c := &serverChan{
 					channel: channel{
-						conn:      s,
-						remoteId:  msg.PeersId,
-						remoteWin: window{Cond: newCond()},
-						maxPacket: msg.MaxPacketSize,
+						packetConn: s,
+						remoteId:   msg.PeersId,
+						remoteWin:  window{Cond: newCond()},
+						maxPacket:  msg.MaxPacketSize,
 					},
 					chanType:    msg.ChanType,
 					extraData:   msg.TypeSpecificData,
diff --git a/ssh/transport.go b/ssh/transport.go
index c222caf..bbc4d80 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -29,13 +29,16 @@
 	maxPacket = 256 * 1024
 )
 
-// conn represents an ssh transport that implements packet based
+// packetConn represents a transport that implements packet based
 // operations.
-type conn interface {
+type packetConn interface {
 	// Encrypt and send a packet of data to the remote peer.
 	writePacket(packet []byte) error
 
-	// Close closes the connection.
+	// Read a packet from the connection
+	readPacket() ([]byte, error)
+
+	// Close closes the write-side of the connection.
 	Close() error
 }
 
@@ -74,7 +77,7 @@
 }
 
 // Read and decrypt a single packet from the remote peer.
-func (r *reader) readOnePacket() ([]byte, error) {
+func (r *reader) readPacket() ([]byte, error) {
 	var lengthBytes = make([]byte, 5)
 	var macSize uint32
 	if _, err := io.ReadFull(r, lengthBytes); err != nil {
@@ -128,7 +131,7 @@
 // Read and decrypt next packet discarding debug and noop messages.
 func (t *transport) readPacket() ([]byte, error) {
 	for {
-		packet, err := t.readOnePacket()
+		packet, err := t.reader.readPacket()
 		if err != nil {
 			return nil, err
 		}