go.crypto/ssh: import gosshnew.

See https://groups.google.com/d/msg/Golang-nuts/AoVxQ4bB5XQ/i8kpMxdbVlEJ

R=hanwen
CC=golang-codereviews
https://golang.org/cl/86190043
diff --git a/ssh/agent/client.go b/ssh/agent/client.go
new file mode 100644
index 0000000..9c11d32
--- /dev/null
+++ b/ssh/agent/client.go
@@ -0,0 +1,563 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+  Package agent implements a client to an ssh-agent daemon.
+
+References:
+  [PROTOCOL.agent]:    http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent
+*/
+package agent
+
+import (
+	"bytes"
+	"crypto/dsa"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rsa"
+	"encoding/base64"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"math/big"
+	"sync"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+// Agent represents the capabilities of an ssh-agent.
+type Agent interface {
+	// List returns the identities known to the agent.
+	List() ([]*Key, error)
+
+	// Sign has the agent sign the data using a protocol 2 key as defined
+	// in [PROTOCOL.agent] section 2.6.2.
+	Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
+
+	// Insert adds a private key to the agent. If a certificate
+	// is given, that certificate is added as public key.
+	Add(s interface{}, cert *ssh.Certificate, comment string) error
+
+	// Remove removes all identities with the given public key.
+	Remove(key ssh.PublicKey) error
+
+	// RemoveAll removes all identities.
+	RemoveAll() error
+
+	// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
+	Lock(passphrase []byte) error
+
+	// Unlock undoes the effect of Lock
+	Unlock(passphrase []byte) error
+
+	// Signers returns signers for all the known keys.
+	Signers() ([]ssh.Signer, error)
+}
+
+// See [PROTOCOL.agent], section 3.
+const (
+	agentRequestV1Identities = 1
+
+	// 3.2 Requests from client to agent for protocol 2 key operations
+	agentAddIdentity         = 17
+	agentRemoveIdentity      = 18
+	agentRemoveAllIdentities = 19
+	agentAddIdConstrained    = 25
+
+	// 3.3 Key-type independent requests from client to agent
+	agentAddSmartcardKey            = 20
+	agentRemoveSmartcardKey         = 21
+	agentLock                       = 22
+	agentUnlock                     = 23
+	agentAddSmartcardKeyConstrained = 26
+
+	// 3.7 Key constraint identifiers
+	agentConstrainLifetime = 1
+	agentConstrainConfirm  = 2
+)
+
+// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
+// is a sanity check, not a limit in the spec.
+const maxAgentResponseBytes = 16 << 20
+
+// Agent messages:
+// These structures mirror the wire format of the corresponding ssh agent
+// messages found in [PROTOCOL.agent].
+
+// 3.4 Generic replies from agent to client
+const agentFailure = 5
+
+type failureAgentMsg struct{}
+
+const agentSuccess = 6
+
+type successAgentMsg struct{}
+
+// See [PROTOCOL.agent], section 2.5.2.
+const agentRequestIdentities = 11
+
+type requestIdentitiesAgentMsg struct{}
+
+// See [PROTOCOL.agent], section 2.5.2.
+const agentIdentitiesAnswer = 12
+
+type identitiesAnswerAgentMsg struct {
+	NumKeys uint32 `sshtype:"12"`
+	Keys    []byte `ssh:"rest"`
+}
+
+// See [PROTOCOL.agent], section 2.6.2.
+const agentSignRequest = 13
+
+type signRequestAgentMsg struct {
+	KeyBlob []byte `sshtype:"13"`
+	Data    []byte
+	Flags   uint32
+}
+
+// See [PROTOCOL.agent], section 2.6.2.
+
+// 3.6 Replies from agent to client for protocol 2 key operations
+const agentSignResponse = 14
+
+type signResponseAgentMsg struct {
+	SigBlob []byte `sshtype:"14"`
+}
+
+type publicKey struct {
+	Format string
+	Rest   []byte `ssh:"rest"`
+}
+
+// Key represents a protocol 2 public key as defined in
+// [PROTOCOL.agent], section 2.5.2.
+type Key struct {
+	Format  string
+	Blob    []byte
+	Comment string
+}
+
+func clientErr(err error) error {
+	return fmt.Errorf("agent: client error: %v", err)
+}
+
+// String returns the storage form of an agent key with the format, base64
+// encoded serialized key, and the comment if it is not empty.
+func (k *Key) String() string {
+	s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob)
+
+	if k.Comment != "" {
+		s += " " + k.Comment
+	}
+
+	return s
+}
+
+// Type returns the public key type.
+func (k *Key) Type() string {
+	return k.Format
+}
+
+// Marshal returns key blob to satisfy the ssh.PublicKey interface.
+func (k *Key) Marshal() []byte {
+	return k.Blob
+}
+
+// Verify satisfies the ssh.PublicKey interface, but is not
+// implemented for agent keys.
+func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
+	return errors.New("agent: agent key does not know how to verify")
+}
+
+type wireKey struct {
+	Format string
+	Rest   []byte `ssh:"rest"`
+}
+
+func parseKey(in []byte) (out *Key, rest []byte, err error) {
+	var record struct {
+		Blob    []byte
+		Comment string
+		Rest    []byte `ssh:"rest"`
+	}
+
+	if err := ssh.Unmarshal(in, &record); err != nil {
+		return nil, nil, err
+	}
+
+	var wk wireKey
+	if err := ssh.Unmarshal(record.Blob, &wk); err != nil {
+		return nil, nil, err
+	}
+
+	return &Key{
+		Format:  wk.Format,
+		Blob:    record.Blob,
+		Comment: record.Comment,
+	}, record.Rest, nil
+}
+
+// client is a client for an ssh-agent process.
+type client struct {
+	// conn is typically a *net.UnixConn
+	conn io.ReadWriter
+	// mu is used to prevent concurrent access to the agent
+	mu sync.Mutex
+}
+
+// NewClient returns an Agent that talks to an ssh-agent process over
+// the given connection.
+func NewClient(rw io.ReadWriter) Agent {
+	return &client{conn: rw}
+}
+
+// call sends an RPC to the agent. On success, the reply is
+// unmarshaled into reply and replyType is set to the first byte of
+// the reply, which contains the type of the message.
+func (c *client) call(req []byte) (reply interface{}, err error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	msg := make([]byte, 4+len(req))
+	binary.BigEndian.PutUint32(msg, uint32(len(req)))
+	copy(msg[4:], req)
+	if _, err = c.conn.Write(msg); err != nil {
+		return nil, clientErr(err)
+	}
+
+	var respSizeBuf [4]byte
+	if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil {
+		return nil, clientErr(err)
+	}
+	respSize := binary.BigEndian.Uint32(respSizeBuf[:])
+	if respSize > maxAgentResponseBytes {
+		return nil, clientErr(err)
+	}
+
+	buf := make([]byte, respSize)
+	if _, err = io.ReadFull(c.conn, buf); err != nil {
+		return nil, clientErr(err)
+	}
+	reply, err = unmarshal(buf)
+	if err != nil {
+		return nil, clientErr(err)
+	}
+	return reply, err
+}
+
+func (c *client) simpleCall(req []byte) error {
+	resp, err := c.call(req)
+	if err != nil {
+		return err
+	}
+	if _, ok := resp.(*successAgentMsg); ok {
+		return nil
+	}
+	return errors.New("agent: failure")
+}
+
+func (c *client) RemoveAll() error {
+	return c.simpleCall([]byte{agentRemoveAllIdentities})
+}
+
+func (c *client) Remove(key ssh.PublicKey) error {
+	req := ssh.Marshal(&agentRemoveIdentityMsg{
+		KeyBlob: key.Marshal(),
+	})
+	return c.simpleCall(req)
+}
+
+func (c *client) Lock(passphrase []byte) error {
+	req := ssh.Marshal(&agentLockMsg{
+		Passphrase: passphrase,
+	})
+	return c.simpleCall(req)
+}
+
+func (c *client) Unlock(passphrase []byte) error {
+	req := ssh.Marshal(&agentUnlockMsg{
+		Passphrase: passphrase,
+	})
+	return c.simpleCall(req)
+}
+
+// List returns the identities known to the agent.
+func (c *client) List() ([]*Key, error) {
+	// see [PROTOCOL.agent] section 2.5.2.
+	req := []byte{agentRequestIdentities}
+
+	msg, err := c.call(req)
+	if err != nil {
+		return nil, err
+	}
+
+	switch msg := msg.(type) {
+	case *identitiesAnswerAgentMsg:
+		if msg.NumKeys > maxAgentResponseBytes/8 {
+			return nil, errors.New("ssh: too many keys in agent reply")
+		}
+		keys := make([]*Key, msg.NumKeys)
+		data := msg.Keys
+		for i := uint32(0); i < msg.NumKeys; i++ {
+			var key *Key
+			var err error
+			if key, data, err = parseKey(data); err != nil {
+				return nil, err
+			}
+			keys[i] = key
+		}
+		return keys, nil
+	case *failureAgentMsg:
+		return nil, errors.New("ssh: failed to list keys")
+	}
+	panic("unreachable")
+}
+
+// Sign has the agent sign the data using a protocol 2 key as defined
+// in [PROTOCOL.agent] section 2.6.2.
+func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+	req := ssh.Marshal(signRequestAgentMsg{
+		KeyBlob: key.Marshal(),
+		Data:    data,
+	})
+
+	msg, err := c.call(req)
+	if err != nil {
+		return nil, err
+	}
+
+	switch msg := msg.(type) {
+	case *signResponseAgentMsg:
+		var sig ssh.Signature
+		if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil {
+			return nil, err
+		}
+
+		return &sig, nil
+	case *failureAgentMsg:
+		return nil, errors.New("ssh: failed to sign challenge")
+	}
+	panic("unreachable")
+}
+
+// unmarshal parses an agent message in packet, returning the parsed
+// form and the message type of packet.
+func unmarshal(packet []byte) (interface{}, error) {
+	if len(packet) < 1 {
+		return nil, errors.New("agent: empty packet")
+	}
+	var msg interface{}
+	switch packet[0] {
+	case agentFailure:
+		return new(failureAgentMsg), nil
+	case agentSuccess:
+		return new(successAgentMsg), nil
+	case agentIdentitiesAnswer:
+		msg = new(identitiesAnswerAgentMsg)
+	case agentSignResponse:
+		msg = new(signResponseAgentMsg)
+	default:
+		return nil, fmt.Errorf("agent: unknown type tag %d", packet[0])
+	}
+	if err := ssh.Unmarshal(packet, msg); err != nil {
+		return nil, err
+	}
+	return msg, nil
+}
+
+type rsaKeyMsg struct {
+	Type     string `sshtype:"17"`
+	N        *big.Int
+	E        *big.Int
+	D        *big.Int
+	Iqmp     *big.Int // IQMP = Inverse Q Mod P
+	P        *big.Int
+	Q        *big.Int
+	Comments string
+}
+
+type dsaKeyMsg struct {
+	Type     string `sshtype:"17"`
+	P        *big.Int
+	Q        *big.Int
+	G        *big.Int
+	Y        *big.Int
+	X        *big.Int
+	Comments string
+}
+
+type ecdsaKeyMsg struct {
+	Type     string `sshtype:"17"`
+	Curve    string
+	KeyBytes []byte
+	D        *big.Int
+	Comments string
+}
+
+// Insert adds a private key to the agent.
+func (c *client) insertKey(s interface{}, comment string) error {
+	var req []byte
+	switch k := s.(type) {
+	case *rsa.PrivateKey:
+		if len(k.Primes) != 2 {
+			return fmt.Errorf("ssh: unsupported RSA key with %d primes", len(k.Primes))
+		}
+		k.Precompute()
+		req = ssh.Marshal(rsaKeyMsg{
+			Type:     ssh.KeyAlgoRSA,
+			N:        k.N,
+			E:        big.NewInt(int64(k.E)),
+			D:        k.D,
+			Iqmp:     k.Precomputed.Qinv,
+			P:        k.Primes[0],
+			Q:        k.Primes[1],
+			Comments: comment,
+		})
+	case *dsa.PrivateKey:
+		req = ssh.Marshal(dsaKeyMsg{
+			Type:     ssh.KeyAlgoDSA,
+			P:        k.P,
+			Q:        k.Q,
+			G:        k.G,
+			Y:        k.Y,
+			X:        k.X,
+			Comments: comment,
+		})
+	case *ecdsa.PrivateKey:
+		nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
+		req = ssh.Marshal(ecdsaKeyMsg{
+			Type:     "ecdsa-sha2-" + nistID,
+			Curve:    nistID,
+			KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
+			D:        k.D,
+			Comments: comment,
+		})
+	default:
+		return fmt.Errorf("ssh: unsupported key type %T", s)
+	}
+	resp, err := c.call(req)
+	if err != nil {
+		return err
+	}
+	if _, ok := resp.(*successAgentMsg); ok {
+		return nil
+	}
+	return errors.New("ssh: failure")
+}
+
+type rsaCertMsg struct {
+	Type      string `sshtype:"17"`
+	CertBytes []byte
+	D         *big.Int
+	Iqmp      *big.Int // IQMP = Inverse Q Mod P
+	P         *big.Int
+	Q         *big.Int
+	Comments  string
+}
+
+type dsaCertMsg struct {
+	Type      string `sshtype:"17"`
+	CertBytes []byte
+	X         *big.Int
+	Comments  string
+}
+
+type ecdsaCertMsg struct {
+	Type      string `sshtype:"17"`
+	CertBytes []byte
+	D         *big.Int
+	Comments  string
+}
+
+// Insert adds a private key to the agent. If a certificate is given,
+// that certificate is added instead as public key.
+func (c *client) Add(s interface{}, cert *ssh.Certificate, comment string) error {
+	if cert == nil {
+		return c.insertKey(s, comment)
+	} else {
+		return c.insertCert(s, cert, comment)
+	}
+}
+
+func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string) error {
+	var req []byte
+	switch k := s.(type) {
+	case *rsa.PrivateKey:
+		if len(k.Primes) != 2 {
+			return fmt.Errorf("ssh: unsupported RSA key with %d primes", len(k.Primes))
+		}
+		k.Precompute()
+		req = ssh.Marshal(rsaCertMsg{
+			Type:      cert.Type(),
+			CertBytes: cert.Marshal(),
+			D:         k.D,
+			Iqmp:      k.Precomputed.Qinv,
+			P:         k.Primes[0],
+			Q:         k.Primes[1],
+			Comments:  comment,
+		})
+	case *dsa.PrivateKey:
+		req = ssh.Marshal(dsaCertMsg{
+			Type:      cert.Type(),
+			CertBytes: cert.Marshal(),
+			X:         k.X,
+			Comments:  comment,
+		})
+	case *ecdsa.PrivateKey:
+		req = ssh.Marshal(ecdsaCertMsg{
+			Type:      cert.Type(),
+			CertBytes: cert.Marshal(),
+			D:         k.D,
+			Comments:  comment,
+		})
+	default:
+		return fmt.Errorf("ssh: unsupported key type %T", s)
+	}
+
+	signer, err := ssh.NewSignerFromKey(s)
+	if err != nil {
+		return err
+	}
+	if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
+		return errors.New("ssh: signer and cert have different public key")
+	}
+
+	resp, err := c.call(req)
+	if err != nil {
+		return err
+	}
+	if _, ok := resp.(*successAgentMsg); ok {
+		return nil
+	}
+	return errors.New("ssh: failure")
+}
+
+// Signers provides a callback for client authentication.
+func (c *client) Signers() ([]ssh.Signer, error) {
+	keys, err := c.List()
+	if err != nil {
+		return nil, err
+	}
+
+	var result []ssh.Signer
+	for _, k := range keys {
+		result = append(result, &agentKeyringSigner{c, k})
+	}
+	return result, nil
+}
+
+type agentKeyringSigner struct {
+	agent *client
+	pub   ssh.PublicKey
+}
+
+func (s *agentKeyringSigner) PublicKey() ssh.PublicKey {
+	return s.pub
+}
+
+func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
+	// The agent has its own entropy source, so the rand argument is ignored.
+	return s.agent.Sign(s.pub, data)
+}
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go
new file mode 100644
index 0000000..aa99e27
--- /dev/null
+++ b/ssh/agent/client_test.go
@@ -0,0 +1,270 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+	"bytes"
+	"crypto/rand"
+	"errors"
+	"net"
+	"os"
+	"os/exec"
+	"strconv"
+	"testing"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) {
+	bin, err := exec.LookPath("ssh-agent")
+	if err != nil {
+		t.Skip("could not find ssh-agent")
+	}
+
+	cmd := exec.Command(bin, "-s")
+	out, err := cmd.Output()
+	if err != nil {
+		t.Fatalf("cmd.Output: %v", err)
+	}
+
+	/* Output looks like:
+
+		   SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
+	           SSH_AGENT_PID=15542; export SSH_AGENT_PID;
+	           echo Agent pid 15542;
+	*/
+	fields := bytes.Split(out, []byte(";"))
+	line := bytes.SplitN(fields[0], []byte("="), 2)
+	line[0] = bytes.TrimLeft(line[0], "\n")
+	if string(line[0]) != "SSH_AUTH_SOCK" {
+		t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
+	}
+	socket = string(line[1])
+
+	line = bytes.SplitN(fields[2], []byte("="), 2)
+	line[0] = bytes.TrimLeft(line[0], "\n")
+	if string(line[0]) != "SSH_AGENT_PID" {
+		t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
+	}
+	pidStr := line[1]
+	pid, err := strconv.Atoi(string(pidStr))
+	if err != nil {
+		t.Fatalf("Atoi(%q): %v", pidStr, err)
+	}
+
+	conn, err := net.Dial("unix", string(socket))
+	if err != nil {
+		t.Fatalf("net.Dial: %v", err)
+	}
+
+	ac := NewClient(conn)
+	return ac, socket, func() {
+		proc, _ := os.FindProcess(pid)
+		if proc != nil {
+			proc.Kill()
+		}
+		conn.Close()
+	}
+}
+
+func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate) {
+	agent, _, cleanup := startAgent(t)
+	defer cleanup()
+
+	testAgentInterface(t, agent, key, cert)
+}
+
+func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate) {
+	signer, err := ssh.NewSignerFromKey(key)
+	if err != nil {
+		t.Fatalf("NewSignerFromKey: %v", err)
+	}
+	// The agent should start up empty.
+	if keys, err := agent.List(); err != nil {
+		t.Fatalf("RequestIdentities: %v", err)
+	} else if len(keys) > 0 {
+		t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
+	}
+
+	// Attempt to insert the key, with certificate if specified.
+	var pubKey ssh.PublicKey
+	if cert != nil {
+		err = agent.Add(key, cert, "comment")
+		pubKey = cert
+	} else {
+		err = agent.Add(key, nil, "comment")
+		pubKey = signer.PublicKey()
+	}
+	if err != nil {
+		t.Fatalf("insert: %v", err)
+	}
+
+	// Did the key get inserted successfully?
+	if keys, err := agent.List(); err != nil {
+		t.Fatalf("List: %v", err)
+	} else if len(keys) != 1 {
+		t.Fatalf("got %v, want 1 key", keys)
+	} else if keys[0].Comment != "comment" {
+		t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
+	} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) {
+		t.Fatalf("key mismatch")
+	}
+
+	// Can the agent make a valid signature?
+	data := []byte("hello")
+	sig, err := agent.Sign(pubKey, data)
+	if err != nil {
+		t.Fatalf("Sign: %v", err)
+	}
+
+	if err := pubKey.Verify(data, sig); err != nil {
+		t.Fatalf("key signature Verify: %v", err)
+	}
+}
+
+func TestAgent(t *testing.T) {
+	for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
+		t.Log(keyType)
+		testAgent(t, testPrivateKeys[keyType], nil)
+	}
+}
+
+func TestCert(t *testing.T) {
+	cert := &ssh.Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: ssh.CertTimeInfinity,
+		CertType:    ssh.UserCert,
+	}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+
+	testAgent(t, testPrivateKeys["rsa"], cert)
+}
+
+// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
+// therefore is buffered (net.Pipe deadlocks if both sides start with
+// a write.)
+func netPipe() (net.Conn, net.Conn, error) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return nil, nil, err
+	}
+	defer listener.Close()
+	c1, err := net.Dial("tcp", listener.Addr().String())
+	if err != nil {
+		return nil, nil, err
+	}
+
+	c2, err := listener.Accept()
+	if err != nil {
+		c1.Close()
+		return nil, nil, err
+	}
+
+	return c1, c2, nil
+}
+
+func TestAuth(t *testing.T) {
+	a, b, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+
+	defer a.Close()
+	defer b.Close()
+
+	agent, _, cleanup := startAgent(t)
+	defer cleanup()
+
+	if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil {
+		t.Errorf("Add: %v", err)
+	}
+
+	serverConf := ssh.ServerConfig{}
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+		if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+			return nil, nil
+		}
+
+		return nil, errors.New("pubkey rejected")
+	}
+
+	go func() {
+		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
+		if err != nil {
+			t.Fatalf("Server: %v", err)
+		}
+		conn.Close()
+	}()
+
+	conf := ssh.ClientConfig{}
+	conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
+	conn, _, _, err := ssh.NewClientConn(b, "", &conf)
+	if err != nil {
+		t.Fatalf("NewClientConn: %v", err)
+	}
+	conn.Close()
+}
+
+func TestLockClient(t *testing.T) {
+	agent, _, cleanup := startAgent(t)
+	defer cleanup()
+	testLockAgent(agent, t)
+}
+
+func testLockAgent(agent Agent, t *testing.T) {
+	if err := agent.Add(testPrivateKeys["rsa"], nil, "comment 1"); err != nil {
+		t.Errorf("Add: %v", err)
+	}
+	if err := agent.Add(testPrivateKeys["dsa"], nil, "comment dsa"); err != nil {
+		t.Errorf("Add: %v", err)
+	}
+	if keys, err := agent.List(); err != nil {
+		t.Errorf("List: %v", err)
+	} else if len(keys) != 2 {
+		t.Errorf("Want 2 keys, got %v", keys)
+	}
+
+	passphrase := []byte("secret")
+	if err := agent.Lock(passphrase); err != nil {
+		t.Errorf("Lock: %v", err)
+	}
+
+	if keys, err := agent.List(); err != nil {
+		t.Errorf("List: %v", err)
+	} else if len(keys) != 0 {
+		t.Errorf("Want 0 keys, got %v", keys)
+	}
+
+	signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
+	if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
+		t.Fatalf("Sign did not fail")
+	}
+
+	if err := agent.Remove(signer.PublicKey()); err == nil {
+		t.Fatalf("Remove did not fail")
+	}
+
+	if err := agent.RemoveAll(); err == nil {
+		t.Fatalf("RemoveAll did not fail")
+	}
+
+	if err := agent.Unlock(nil); err == nil {
+		t.Errorf("Unlock with wrong passphrase succeeded")
+	}
+	if err := agent.Unlock(passphrase); err != nil {
+		t.Errorf("Unlock: %v", err)
+	}
+
+	if err := agent.Remove(signer.PublicKey()); err != nil {
+		t.Fatalf("Remove: %v", err)
+	}
+
+	if keys, err := agent.List(); err != nil {
+		t.Errorf("List: %v", err)
+	} else if len(keys) != 1 {
+		t.Errorf("Want 1 keys, got %v", keys)
+	}
+}
diff --git a/ssh/agent/forward.go b/ssh/agent/forward.go
new file mode 100644
index 0000000..dd45c3e
--- /dev/null
+++ b/ssh/agent/forward.go
@@ -0,0 +1,103 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+	"errors"
+	"io"
+	"net"
+	"sync"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+// RequestAgentForwarding sets up agent forwarding for the session.
+// SetupForwardKeyring or SetupForwardAgent should be called to route
+// the authentication requests.
+func RequestAgentForwarding(session *ssh.Session) error {
+	ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
+	if err != nil {
+		return err
+	}
+	if !ok {
+		return errors.New("forwarding request denied")
+	}
+	return nil
+}
+
+// ForwardToAgent routes authentication requests to the given keyring.
+func ForwardToAgent(client *ssh.Client, keyring Agent) error {
+	channels := client.HandleChannelOpen(channelType)
+	if channels == nil {
+		return errors.New("agent: already have handler for " + channelType)
+	}
+
+	go func() {
+		for ch := range channels {
+			channel, reqs, err := ch.Accept()
+			if err != nil {
+				continue
+			}
+			go ssh.DiscardRequests(reqs)
+			go func() {
+				ServeAgent(keyring, channel)
+				channel.Close()
+			}()
+		}
+	}()
+	return nil
+}
+
+const channelType = "auth-agent@openssh.com"
+
+// ForwardToRemote routes authentication requests to the ssh-agent
+// process serving on the given unix socket.
+func ForwardToRemote(client *ssh.Client, addr string) error {
+	channels := client.HandleChannelOpen(channelType)
+	if channels == nil {
+		return errors.New("agent: already have handler for " + channelType)
+	}
+	conn, err := net.Dial("unix", addr)
+	if err != nil {
+		return err
+	}
+	conn.Close()
+
+	go func() {
+		for ch := range channels {
+			channel, reqs, err := ch.Accept()
+			if err != nil {
+				continue
+			}
+			go ssh.DiscardRequests(reqs)
+			go forwardUnixSocket(channel, addr)
+		}
+	}()
+	return nil
+}
+
+func forwardUnixSocket(channel ssh.Channel, addr string) {
+	conn, err := net.Dial("unix", addr)
+	if err != nil {
+		return
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func() {
+		io.Copy(conn, channel)
+		conn.(*net.UnixConn).CloseWrite()
+		wg.Done()
+	}()
+	go func() {
+		io.Copy(channel, conn)
+		channel.CloseWrite()
+		wg.Done()
+	}()
+
+	wg.Wait()
+	conn.Close()
+	channel.Close()
+}
diff --git a/ssh/agent/keyring.go b/ssh/agent/keyring.go
new file mode 100644
index 0000000..ecfa66f
--- /dev/null
+++ b/ssh/agent/keyring.go
@@ -0,0 +1,183 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+	"bytes"
+	"crypto/rand"
+	"crypto/subtle"
+	"errors"
+	"fmt"
+	"sync"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+type privKey struct {
+	signer  ssh.Signer
+	comment string
+}
+
+type keyring struct {
+	mu   sync.Mutex
+	keys []privKey
+
+	locked     bool
+	passphrase []byte
+}
+
+var errLocked = errors.New("agent: locked")
+
+// NewKeyring returns an Agent that holds keys in memory.  It is safe
+// for concurrent use by multiple goroutines.
+func NewKeyring() Agent {
+	return &keyring{}
+}
+
+// RemoveAll removes all identities.
+func (r *keyring) RemoveAll() error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return errLocked
+	}
+
+	r.keys = nil
+	return nil
+}
+
+// Remove removes all identities with the given public key.
+func (r *keyring) Remove(key ssh.PublicKey) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return errLocked
+	}
+
+	want := key.Marshal()
+	found := false
+	for i := 0; i < len(r.keys); {
+		if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
+			found = true
+			r.keys[i] = r.keys[len(r.keys)-1]
+			r.keys = r.keys[len(r.keys)-1:]
+			continue
+		} else {
+			i++
+		}
+	}
+
+	if !found {
+		return errors.New("agent: key not found")
+	}
+	return nil
+}
+
+// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
+func (r *keyring) Lock(passphrase []byte) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return errLocked
+	}
+
+	r.locked = true
+	r.passphrase = passphrase
+	return nil
+}
+
+// Unlock undoes the effect of Lock
+func (r *keyring) Unlock(passphrase []byte) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if !r.locked {
+		return errors.New("agent: not locked")
+	}
+	if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
+		return fmt.Errorf("agent: incorrect passphrase")
+	}
+
+	r.locked = false
+	r.passphrase = nil
+	return nil
+}
+
+// List returns the identities known to the agent.
+func (r *keyring) List() ([]*Key, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		// section 2.7: locked agents return empty.
+		return nil, nil
+	}
+
+	var ids []*Key
+	for _, k := range r.keys {
+		pub := k.signer.PublicKey()
+		ids = append(ids, &Key{
+			Format:  pub.Type(),
+			Blob:    pub.Marshal(),
+			Comment: k.comment})
+	}
+	return ids, nil
+}
+
+// Insert adds a private key to the keyring. If a certificate
+// is given, that certificate is added as public key.
+func (r *keyring) Add(priv interface{}, cert *ssh.Certificate, comment string) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return errLocked
+	}
+	signer, err := ssh.NewSignerFromKey(priv)
+
+	if err != nil {
+		return err
+	}
+
+	if cert != nil {
+		signer, err = ssh.NewCertSigner(cert, signer)
+		if err != nil {
+			return err
+		}
+	}
+
+	r.keys = append(r.keys, privKey{signer, comment})
+
+	return nil
+}
+
+// Sign returns a signature for the data.
+func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return nil, errLocked
+	}
+
+	wanted := key.Marshal()
+	for _, k := range r.keys {
+		if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
+			return k.signer.Sign(rand.Reader, data)
+		}
+	}
+	return nil, errors.New("not found")
+}
+
+// Signers returns signers for all the known keys.
+func (r *keyring) Signers() ([]ssh.Signer, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return nil, errLocked
+	}
+
+	s := make([]ssh.Signer, len(r.keys))
+	for _, k := range r.keys {
+		s = append(s, k.signer)
+	}
+	return s, nil
+}
diff --git a/ssh/agent/server.go b/ssh/agent/server.go
new file mode 100644
index 0000000..2d55dc9
--- /dev/null
+++ b/ssh/agent/server.go
@@ -0,0 +1,209 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+	"crypto/rsa"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"log"
+	"math/big"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+// Server wraps an Agent and uses it to implement the agent side of
+// the SSH-agent, wire protocol.
+type server struct {
+	agent Agent
+}
+
+func (s *server) processRequestBytes(reqData []byte) []byte {
+	rep, err := s.processRequest(reqData)
+	if err != nil {
+		if err != errLocked {
+			// TODO(hanwen): provide better logging interface?
+			log.Printf("agent %d: %v", reqData[0], err)
+		}
+		return []byte{agentFailure}
+	}
+
+	if err == nil && rep == nil {
+		return []byte{agentSuccess}
+	}
+
+	return ssh.Marshal(rep)
+}
+
+func marshalKey(k *Key) []byte {
+	var record struct {
+		Blob    []byte
+		Comment string
+	}
+	record.Blob = k.Marshal()
+	record.Comment = k.Comment
+
+	return ssh.Marshal(&record)
+}
+
+type agentV1IdentityMsg struct {
+	Numkeys uint32 `sshtype:"2"`
+}
+
+type agentRemoveIdentityMsg struct {
+	KeyBlob []byte `sshtype:"18"`
+}
+
+type agentLockMsg struct {
+	Passphrase []byte `sshtype:"22"`
+}
+
+type agentUnlockMsg struct {
+	Passphrase []byte `sshtype:"23"`
+}
+
+func (s *server) processRequest(data []byte) (interface{}, error) {
+	switch data[0] {
+	case agentRequestV1Identities:
+		return &agentV1IdentityMsg{0}, nil
+	case agentRemoveIdentity:
+		var req agentRemoveIdentityMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+
+		var wk wireKey
+		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
+			return nil, err
+		}
+
+		return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
+
+	case agentRemoveAllIdentities:
+		return nil, s.agent.RemoveAll()
+
+	case agentLock:
+		var req agentLockMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+
+		return nil, s.agent.Lock(req.Passphrase)
+
+	case agentUnlock:
+		var req agentLockMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+		return nil, s.agent.Unlock(req.Passphrase)
+
+	case agentSignRequest:
+		var req signRequestAgentMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+
+		var wk wireKey
+		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
+			return nil, err
+		}
+
+		k := &Key{
+			Format: wk.Format,
+			Blob:   req.KeyBlob,
+		}
+
+		sig, err := s.agent.Sign(k, req.Data) //  TODO(hanwen): flags.
+		if err != nil {
+			return nil, err
+		}
+		return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
+	case agentRequestIdentities:
+		keys, err := s.agent.List()
+		if err != nil {
+			return nil, err
+		}
+
+		rep := identitiesAnswerAgentMsg{
+			NumKeys: uint32(len(keys)),
+		}
+		for _, k := range keys {
+			rep.Keys = append(rep.Keys, marshalKey(k)...)
+		}
+		return rep, nil
+	case agentAddIdentity:
+		return nil, s.insertIdentity(data)
+	}
+
+	return nil, fmt.Errorf("unknown opcode %d", data[0])
+}
+
+func (s *server) insertIdentity(req []byte) error {
+	var record struct {
+		Type string `sshtype:"17"`
+		Rest []byte `ssh:"rest"`
+	}
+	if err := ssh.Unmarshal(req, &record); err != nil {
+		return err
+	}
+
+	switch record.Type {
+	case ssh.KeyAlgoRSA:
+		var k rsaKeyMsg
+		if err := ssh.Unmarshal(req, &k); err != nil {
+			return err
+		}
+
+		priv := rsa.PrivateKey{
+			PublicKey: rsa.PublicKey{
+				E: int(k.E.Int64()),
+				N: k.N,
+			},
+			D:      k.D,
+			Primes: []*big.Int{k.P, k.Q},
+		}
+		priv.Precompute()
+
+		return s.agent.Add(&priv, nil, k.Comments)
+	}
+	return fmt.Errorf("not implemented: %s", record.Type)
+}
+
+// ServeAgent serves the agent protocol on the given connection. It
+// returns when an I/O error occurs.
+func ServeAgent(agent Agent, c io.ReadWriter) error {
+	s := &server{agent}
+
+	var length [4]byte
+	for {
+		if _, err := io.ReadFull(c, length[:]); err != nil {
+			return err
+		}
+		l := binary.BigEndian.Uint32(length[:])
+		if l > maxAgentResponseBytes {
+			// We also cap requests.
+			return fmt.Errorf("agent: request too large: %d", l)
+		}
+
+		req := make([]byte, l)
+		if _, err := io.ReadFull(c, req); err != nil {
+			return err
+		}
+
+		repData := s.processRequestBytes(req)
+		if len(repData) > maxAgentResponseBytes {
+			return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
+		}
+
+		binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
+		if _, err := c.Write(length[:]); err != nil {
+			return err
+		}
+		if _, err := c.Write(repData); err != nil {
+			return err
+		}
+	}
+}
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go
new file mode 100644
index 0000000..ad2996b
--- /dev/null
+++ b/ssh/agent/server_test.go
@@ -0,0 +1,77 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+	"testing"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+func TestServer(t *testing.T) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	client := NewClient(c1)
+
+	go ServeAgent(NewKeyring(), c2)
+
+	testAgentInterface(t, client, testPrivateKeys["rsa"], nil)
+}
+
+func TestLockServer(t *testing.T) {
+	testLockAgent(NewKeyring(), t)
+}
+
+func TestSetupForwardAgent(t *testing.T) {
+	a, b, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+
+	defer a.Close()
+	defer b.Close()
+
+	_, socket, cleanup := startAgent(t)
+	defer cleanup()
+
+	serverConf := ssh.ServerConfig{
+		NoClientAuth: true,
+	}
+	serverConf.AddHostKey(testSigners["rsa"])
+	incoming := make(chan *ssh.ServerConn, 1)
+	go func() {
+		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
+		if err != nil {
+			t.Fatalf("Server: %v", err)
+		}
+		incoming <- conn
+	}()
+
+	conf := ssh.ClientConfig{}
+	conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
+	if err != nil {
+		t.Fatalf("NewClientConn: %v", err)
+	}
+	client := ssh.NewClient(conn, chans, reqs)
+
+	if err := ForwardToRemote(client, socket); err != nil {
+		t.Fatalf("SetupForwardAgent: %v", err)
+	}
+
+	server := <-incoming
+	ch, reqs, err := server.OpenChannel(channelType, nil)
+	if err != nil {
+		t.Fatalf("OpenChannel(%q): %v", channelType, err)
+	}
+	go ssh.DiscardRequests(reqs)
+
+	agentClient := NewClient(ch)
+	testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil)
+	conn.Close()
+}
diff --git a/ssh/agent/testdata_test.go b/ssh/agent/testdata_test.go
new file mode 100644
index 0000000..6bb75a9
--- /dev/null
+++ b/ssh/agent/testdata_test.go
@@ -0,0 +1,64 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
+// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
+// instances.
+
+package agent
+
+import (
+	"crypto/rand"
+	"fmt"
+
+	"code.google.com/p/go.crypto/ssh"
+	"code.google.com/p/go.crypto/ssh/testdata"
+)
+
+var (
+	testPrivateKeys map[string]interface{}
+	testSigners     map[string]ssh.Signer
+	testPublicKeys  map[string]ssh.PublicKey
+)
+
+func init() {
+	var err error
+
+	n := len(testdata.PEMBytes)
+	testPrivateKeys = make(map[string]interface{}, n)
+	testSigners = make(map[string]ssh.Signer, n)
+	testPublicKeys = make(map[string]ssh.PublicKey, n)
+	for t, k := range testdata.PEMBytes {
+		testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
+		if err != nil {
+			panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
+		}
+		testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
+		if err != nil {
+			panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
+		}
+		testPublicKeys[t] = testSigners[t].PublicKey()
+	}
+
+	// Create a cert and sign it for use in tests.
+	testCert := &ssh.Certificate{
+		Nonce:           []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
+		ValidAfter:      0,                              // unix epoch
+		ValidBefore:     ssh.CertTimeInfinity,           // The end of currently representable time.
+		Reserved:        []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		Key:             testPublicKeys["ecdsa"],
+		SignatureKey:    testPublicKeys["rsa"],
+		Permissions: ssh.Permissions{
+			CriticalOptions: map[string]string{},
+			Extensions:      map[string]string{},
+		},
+	}
+	testCert.SignCert(rand.Reader, testSigners["rsa"])
+	testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
+	testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
+	if err != nil {
+		panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
+	}
+}
diff --git a/ssh/benchmark_test.go b/ssh/benchmark_test.go
new file mode 100644
index 0000000..d9f7eb9
--- /dev/null
+++ b/ssh/benchmark_test.go
@@ -0,0 +1,122 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"errors"
+	"io"
+	"net"
+	"testing"
+)
+
+type server struct {
+	*ServerConn
+	chans <-chan NewChannel
+}
+
+func newServer(c net.Conn, conf *ServerConfig) (*server, error) {
+	sconn, chans, reqs, err := NewServerConn(c, conf)
+	if err != nil {
+		return nil, err
+	}
+	go DiscardRequests(reqs)
+	return &server{sconn, chans}, nil
+}
+
+func (s *server) Accept() (NewChannel, error) {
+	n, ok := <-s.chans
+	if !ok {
+		return nil, io.EOF
+	}
+	return n, nil
+}
+
+func sshPipe() (Conn, *server, error) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	clientConf := ClientConfig{
+		User: "user",
+	}
+	serverConf := ServerConfig{
+		NoClientAuth: true,
+	}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	done := make(chan *server, 1)
+	go func() {
+		server, err := newServer(c2, &serverConf)
+		if err != nil {
+			done <- nil
+		}
+		done <- server
+	}()
+
+	client, _, reqs, err := NewClientConn(c1, "", &clientConf)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	server := <-done
+	if server == nil {
+		return nil, nil, errors.New("server handshake failed.")
+	}
+	go DiscardRequests(reqs)
+
+	return client, server, nil
+}
+
+func BenchmarkEndToEnd(b *testing.B) {
+	b.StopTimer()
+
+	client, server, err := sshPipe()
+	if err != nil {
+		b.Fatalf("sshPipe: %v", err)
+	}
+
+	defer client.Close()
+	defer server.Close()
+
+	size := (1 << 20)
+	input := make([]byte, size)
+	output := make([]byte, size)
+	b.SetBytes(int64(size))
+	done := make(chan int, 1)
+
+	go func() {
+		newCh, err := server.Accept()
+		if err != nil {
+			b.Fatalf("Client: %v", err)
+		}
+		ch, incoming, err := newCh.Accept()
+		go DiscardRequests(incoming)
+		for i := 0; i < b.N; i++ {
+			if _, err := io.ReadFull(ch, output); err != nil {
+				b.Fatalf("ReadFull: %v", err)
+			}
+		}
+		ch.Close()
+		done <- 1
+	}()
+
+	ch, in, err := client.OpenChannel("speed", nil)
+	if err != nil {
+		b.Fatalf("OpenChannel: %v", err)
+	}
+	go DiscardRequests(in)
+
+	b.ResetTimer()
+	b.StartTimer()
+	for i := 0; i < b.N; i++ {
+		if _, err := ch.Write(input); err != nil {
+			b.Fatalf("WriteFull: %v", err)
+		}
+	}
+	ch.Close()
+	b.StopTimer()
+
+	<-done
+}
diff --git a/ssh/buffer.go b/ssh/buffer.go
index 601dad3..6931b51 100644
--- a/ssh/buffer.go
+++ b/ssh/buffer.go
@@ -43,29 +43,29 @@
 // buf must not be modified after the call to write.
 func (b *buffer) write(buf []byte) {
 	b.Cond.L.Lock()
-	defer b.Cond.L.Unlock()
 	e := &element{buf: buf}
 	b.tail.next = e
 	b.tail = e
 	b.Cond.Signal()
+	b.Cond.L.Unlock()
 }
 
 // eof closes the buffer. Reads from the buffer once all
 // the data has been consumed will receive os.EOF.
 func (b *buffer) eof() error {
 	b.Cond.L.Lock()
-	defer b.Cond.L.Unlock()
 	b.closed = true
 	b.Cond.Signal()
+	b.Cond.L.Unlock()
 	return nil
 }
 
-// Read reads data from the internal buffer in buf.
-// Reads will block if no data is available, or until
-// the buffer is closed.
+// Read reads data from the internal buffer in buf.  Reads will block
+// if no data is available, or until the buffer is closed.
 func (b *buffer) Read(buf []byte) (n int, err error) {
 	b.Cond.L.Lock()
 	defer b.Cond.L.Unlock()
+
 	for len(buf) > 0 {
 		// if there is data in b.head, copy it
 		if len(b.head.buf) > 0 {
@@ -79,10 +79,12 @@
 			b.head = b.head.next
 			continue
 		}
+
 		// if at least one byte has been copied, return
 		if n > 0 {
 			break
 		}
+
 		// if nothing was read, and there is nothing outstanding
 		// check to see if the buffer is closed.
 		if b.closed {
diff --git a/ssh/buffer_test.go b/ssh/buffer_test.go
index 135c4ae..d5781cb 100644
--- a/ssh/buffer_test.go
+++ b/ssh/buffer_test.go
@@ -9,33 +9,33 @@
 	"testing"
 )
 
-var BYTES = []byte("abcdefghijklmnopqrstuvwxyz")
+var alphabet = []byte("abcdefghijklmnopqrstuvwxyz")
 
 func TestBufferReadwrite(t *testing.T) {
 	b := newBuffer()
-	b.write(BYTES[:10])
+	b.write(alphabet[:10])
 	r, _ := b.Read(make([]byte, 10))
 	if r != 10 {
 		t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
 	}
 
 	b = newBuffer()
-	b.write(BYTES[:5])
+	b.write(alphabet[:5])
 	r, _ = b.Read(make([]byte, 10))
 	if r != 5 {
 		t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
 	}
 
 	b = newBuffer()
-	b.write(BYTES[:10])
+	b.write(alphabet[:10])
 	r, _ = b.Read(make([]byte, 5))
 	if r != 5 {
 		t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
 	}
 
 	b = newBuffer()
-	b.write(BYTES[:5])
-	b.write(BYTES[5:15])
+	b.write(alphabet[:5])
+	b.write(alphabet[5:15])
 	r, _ = b.Read(make([]byte, 10))
 	r2, _ := b.Read(make([]byte, 10))
 	if r != 10 || r2 != 5 || 15 != r+r2 {
@@ -45,14 +45,14 @@
 
 func TestBufferClose(t *testing.T) {
 	b := newBuffer()
-	b.write(BYTES[:10])
+	b.write(alphabet[:10])
 	b.eof()
 	_, err := b.Read(make([]byte, 5))
 	if err != nil {
 		t.Fatal("expected read of 5 to not return EOF")
 	}
 	b = newBuffer()
-	b.write(BYTES[:10])
+	b.write(alphabet[:10])
 	b.eof()
 	r, err := b.Read(make([]byte, 5))
 	r2, err2 := b.Read(make([]byte, 10))
@@ -61,7 +61,7 @@
 	}
 
 	b = newBuffer()
-	b.write(BYTES[:10])
+	b.write(alphabet[:10])
 	b.eof()
 	r, err = b.Read(make([]byte, 5))
 	r2, err2 = b.Read(make([]byte, 10))
diff --git a/ssh/certs.go b/ssh/certs.go
index d958f31..9962ff0 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -5,6 +5,12 @@
 package ssh
 
 import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"sort"
 	"time"
 )
 
@@ -18,67 +24,348 @@
 	CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
 )
 
-// Certificate types are used to specify whether a certificate is for identification
-// of a user or a host.  Current identities are defined in [PROTOCOL.certkeys].
+// Certificate types distinguish between host and user
+// certificates. The values can be set in the CertType field of
+// Certificate.
 const (
 	UserCert = 1
 	HostCert = 2
 )
 
-type signature struct {
+// Signature represents a cryptographic signature.
+type Signature struct {
 	Format string
 	Blob   []byte
 }
 
-type tuple struct {
-	Name string
-	Data string
-}
+// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
+// a certificate does not expire.
+const CertTimeInfinity = 1<<64 - 1
 
-const (
-	maxUint64 = 1<<64 - 1
-	maxInt64  = 1<<63 - 1
-)
-
-// CertTime represents an unsigned 64-bit time value in seconds starting from
-// UNIX epoch.  We use CertTime instead of time.Time in order to properly handle
-// the "infinite" time value ^0, which would become negative when expressed as
-// an int64.
-type CertTime uint64
-
-func (ct CertTime) Time() time.Time {
-	if ct > maxInt64 {
-		return time.Unix(maxInt64, 0)
-	}
-	return time.Unix(int64(ct), 0)
-}
-
-func (ct CertTime) IsInfinite() bool {
-	return ct == maxUint64
-}
-
-// An OpenSSHCertV01 represents an OpenSSH certificate as defined in
+// An Certificate represents an OpenSSH certificate as defined in
 // [PROTOCOL.certkeys]?rev=1.8.
-type OpenSSHCertV01 struct {
-	Nonce                   []byte
-	Key                     PublicKey
-	Serial                  uint64
-	Type                    uint32
-	KeyId                   string
-	ValidPrincipals         []string
-	ValidAfter, ValidBefore CertTime
-	CriticalOptions         []tuple
-	Extensions              []tuple
-	Reserved                []byte
-	SignatureKey            PublicKey
-	Signature               *signature
+type Certificate struct {
+	Nonce           []byte
+	Key             PublicKey
+	Serial          uint64
+	CertType        uint32
+	KeyId           string
+	ValidPrincipals []string
+	ValidAfter      uint64
+	ValidBefore     uint64
+	Permissions
+	Reserved     []byte
+	SignatureKey PublicKey
+	Signature    *Signature
 }
 
-// validateOpenSSHCertV01Signature uses the cert's SignatureKey to verify that
-// the cert's Signature.Blob is the result of signing the cert bytes starting
-// from the algorithm string and going up to and including the SignatureKey.
-func validateOpenSSHCertV01Signature(cert *OpenSSHCertV01) bool {
-	return cert.SignatureKey.Verify(cert.BytesForSigning(), cert.Signature.Blob)
+// genericCertData holds the key-independent part of the certificate data.
+// Overall, certificates contain an nonce, public key fields and
+// key-independent fields.
+type genericCertData struct {
+	Serial          uint64
+	CertType        uint32
+	KeyId           string
+	ValidPrincipals []byte
+	ValidAfter      uint64
+	ValidBefore     uint64
+	CriticalOptions []byte
+	Extensions      []byte
+	Reserved        []byte
+	SignatureKey    []byte
+	Signature       []byte
+}
+
+func marshalStringList(namelist []string) []byte {
+	var to []byte
+	for _, name := range namelist {
+		s := struct{ N string }{name}
+		to = append(to, Marshal(&s)...)
+	}
+	return to
+}
+
+func marshalTuples(tups map[string]string) []byte {
+	keys := make([]string, 0, len(tups))
+	for k := range tups {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+
+	var r []byte
+	for _, k := range keys {
+		s := struct{ K, V string }{k, tups[k]}
+		r = append(r, Marshal(&s)...)
+	}
+	return r
+}
+
+func parseTuples(in []byte) (map[string]string, error) {
+	tups := map[string]string{}
+	var lastKey string
+	var haveLastKey bool
+
+	for len(in) > 0 {
+		nameBytes, rest, ok := parseString(in)
+		if !ok {
+			return nil, errShortRead
+		}
+		data, rest, ok := parseString(rest)
+		if !ok {
+			return nil, errShortRead
+		}
+		name := string(nameBytes)
+
+		// according to [PROTOCOL.certkeys], the names must be in
+		// lexical order.
+		if haveLastKey && name <= lastKey {
+			return nil, fmt.Errorf("ssh: certificate options are not in lexical order")
+		}
+		lastKey, haveLastKey = name, true
+
+		tups[name] = string(data)
+		in = rest
+	}
+	return tups, nil
+}
+
+func parseCert(in []byte, privAlgo string) (*Certificate, error) {
+	nonce, rest, ok := parseString(in)
+	if !ok {
+		return nil, errShortRead
+	}
+
+	key, rest, err := parsePubKey(rest, privAlgo)
+	if err != nil {
+		return nil, err
+	}
+
+	var g genericCertData
+	if err := Unmarshal(rest, &g); err != nil {
+		return nil, err
+	}
+
+	c := &Certificate{
+		Nonce:       nonce,
+		Key:         key,
+		Serial:      g.Serial,
+		CertType:    g.CertType,
+		KeyId:       g.KeyId,
+		ValidAfter:  g.ValidAfter,
+		ValidBefore: g.ValidBefore,
+	}
+
+	for principals := g.ValidPrincipals; len(principals) > 0; {
+		principal, rest, ok := parseString(principals)
+		if !ok {
+			return nil, errShortRead
+		}
+		c.ValidPrincipals = append(c.ValidPrincipals, string(principal))
+		principals = rest
+	}
+
+	c.CriticalOptions, err = parseTuples(g.CriticalOptions)
+	if err != nil {
+		return nil, err
+	}
+	c.Extensions, err = parseTuples(g.Extensions)
+	if err != nil {
+		return nil, err
+	}
+	c.Reserved = g.Reserved
+	k, err := ParsePublicKey(g.SignatureKey)
+	if err != nil {
+		return nil, err
+	}
+
+	c.SignatureKey = k
+	c.Signature, rest, ok = parseSignatureBody(g.Signature)
+	if !ok || len(rest) > 0 {
+		return nil, errors.New("ssh: signature parse error")
+	}
+
+	return c, nil
+}
+
+type openSSHCertSigner struct {
+	pub    *Certificate
+	signer Signer
+}
+
+// NewCertSigner returns a Signer that signs with the given Certificate, whose
+// private key is held by signer. It returns an error if the public key in cert
+// doesn't match the key used by signer.
+func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
+	if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 {
+		return nil, errors.New("ssh: signer and cert have different public key")
+	}
+
+	return &openSSHCertSigner{cert, signer}, nil
+}
+
+func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
+	return s.signer.Sign(rand, data)
+}
+
+func (s *openSSHCertSigner) PublicKey() PublicKey {
+	return s.pub
+}
+
+const sourceAddressCriticalOption = "source-address"
+
+// CertChecker does the work of verifying a certificate. Its methods
+// can be plugged into ClientConfig.HostKeyCallback and
+// ServerConfig.PublicKeyCallback. For the CertChecker to work,
+// minimally, the IsAuthority callback should be set.
+type CertChecker struct {
+	// SupportedCriticalOptions lists the CriticalOptions that the
+	// server application layer understands. These are only used
+	// for user certificates.
+	SupportedCriticalOptions []string
+
+	// IsAuthority should return true if the key is recognized as
+	// an authority. This allows for certificates to be signed by other
+	// certificates.
+	IsAuthority func(auth PublicKey) bool
+
+	// Clock is used for verifying time stamps. If nil, time.Now
+	// is used.
+	Clock func() time.Time
+
+	// UserKeyFallback is called when CertChecker.Authenticate encounters a
+	// public key that is not a certificate. It must implement validation
+	// of user keys or else, if nil, all such keys are rejected.
+	UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+
+	// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
+	// public key that is not a certificate. It must implement host key
+	// validation or else, if nil, all such keys are rejected.
+	HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error
+
+	// IsRevoked is called for each certificate so that revocation checking
+	// can be implemented. It should return true if the given certificate
+	// is revoked and false otherwise. If nil, no certificates are
+	// considered to have been revoked.
+	IsRevoked func(cert *Certificate) bool
+}
+
+// CheckHostKey checks a host key certificate. This method can be
+// plugged into ClientConfig.HostKeyCallback.
+func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error {
+	cert, ok := key.(*Certificate)
+	if !ok {
+		if c.HostKeyFallback != nil {
+			return c.HostKeyFallback(addr, remote, key)
+		}
+		return errors.New("ssh: non-certificate host key")
+	}
+	if cert.CertType != HostCert {
+		return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
+	}
+
+	return c.CheckCert(addr, cert)
+}
+
+// Authenticate checks a user certificate. Authenticate can be used as
+// a value for ServerConfig.PublicKeyCallback.
+func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) {
+	cert, ok := pubKey.(*Certificate)
+	if !ok {
+		if c.UserKeyFallback != nil {
+			return c.UserKeyFallback(conn, pubKey)
+		}
+		return nil, errors.New("ssh: normal key pairs not accepted")
+	}
+
+	if cert.CertType != UserCert {
+		return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
+	}
+
+	if err := c.CheckCert(conn.User(), cert); err != nil {
+		return nil, err
+	}
+
+	return &cert.Permissions, nil
+}
+
+// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
+// the signature of the certificate.
+func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
+	if c.IsRevoked != nil && c.IsRevoked(cert) {
+		return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
+	}
+
+	for opt, _ := range cert.CriticalOptions {
+		// sourceAddressCriticalOption will be enforced by
+		// serverAuthenticate
+		if opt == sourceAddressCriticalOption {
+			continue
+		}
+
+		found := false
+		for _, supp := range c.SupportedCriticalOptions {
+			if supp == opt {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt)
+		}
+	}
+
+	if len(cert.ValidPrincipals) > 0 {
+		// By default, certs are valid for all users/hosts.
+		found := false
+		for _, p := range cert.ValidPrincipals {
+			if p == principal {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals)
+		}
+	}
+
+	if !c.IsAuthority(cert.SignatureKey) {
+		return fmt.Errorf("ssh: certificate signed by unrecognized authority")
+	}
+
+	clock := c.Clock
+	if clock == nil {
+		clock = time.Now
+	}
+
+	unixNow := clock().Unix()
+	if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
+		return fmt.Errorf("ssh: cert is not yet valid")
+	}
+	if before := int64(cert.ValidBefore); cert.ValidBefore != CertTimeInfinity && (unixNow >= before || before < 0) {
+		return fmt.Errorf("ssh: cert has expired")
+	}
+	if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
+		return fmt.Errorf("ssh: certificate signature does not verify")
+	}
+
+	return nil
+}
+
+// SignCert sets c.SignatureKey to the authority's public key and stores a
+// Signature, by authority, in the certificate.
+func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
+	c.Nonce = make([]byte, 32)
+	if _, err := io.ReadFull(rand, c.Nonce); err != nil {
+		return err
+	}
+	c.SignatureKey = authority.PublicKey()
+
+	sig, err := authority.Sign(rand, c.bytesForSigning())
+	if err != nil {
+		return err
+	}
+	c.Signature = sig
+	return nil
 }
 
 var certAlgoNames = map[string]string{
@@ -100,260 +387,69 @@
 	panic("unknown cert algorithm")
 }
 
-func (cert *OpenSSHCertV01) marshal(includeAlgo, includeSig bool) []byte {
-	algoName := cert.PublicKeyAlgo()
-	pubKey := cert.Key.Marshal()
-	sigKey := MarshalPublicKey(cert.SignatureKey)
-
-	var length int
-	if includeAlgo {
-		length += stringLength(len(algoName))
-	}
-	length += stringLength(len(cert.Nonce))
-	length += len(pubKey)
-	length += 8 // Length of Serial
-	length += 4 // Length of Type
-	length += stringLength(len(cert.KeyId))
-	length += lengthPrefixedNameListLength(cert.ValidPrincipals)
-	length += 8 // Length of ValidAfter
-	length += 8 // Length of ValidBefore
-	length += tupleListLength(cert.CriticalOptions)
-	length += tupleListLength(cert.Extensions)
-	length += stringLength(len(cert.Reserved))
-	length += stringLength(len(sigKey))
-	if includeSig {
-		length += signatureLength(cert.Signature)
-	}
-
-	ret := make([]byte, length)
-	r := ret
-	if includeAlgo {
-		r = marshalString(r, []byte(algoName))
-	}
-	r = marshalString(r, cert.Nonce)
-	copy(r, pubKey)
-	r = r[len(pubKey):]
-	r = marshalUint64(r, cert.Serial)
-	r = marshalUint32(r, cert.Type)
-	r = marshalString(r, []byte(cert.KeyId))
-	r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals)
-	r = marshalUint64(r, uint64(cert.ValidAfter))
-	r = marshalUint64(r, uint64(cert.ValidBefore))
-	r = marshalTupleList(r, cert.CriticalOptions)
-	r = marshalTupleList(r, cert.Extensions)
-	r = marshalString(r, cert.Reserved)
-	r = marshalString(r, sigKey)
-	if includeSig {
-		r = marshalSignature(r, cert.Signature)
-	}
-	if len(r) > 0 {
-		panic("ssh: internal error, marshaling certificate did not fill the entire buffer")
-	}
-	return ret
+func (cert *Certificate) bytesForSigning() []byte {
+	c2 := *cert
+	c2.Signature = nil
+	out := c2.Marshal()
+	// Drop trailing signature length.
+	return out[:len(out)-4]
 }
 
-func (cert *OpenSSHCertV01) BytesForSigning() []byte {
-	return cert.marshal(true, false)
+// Marshal serializes c into OpenSSH's wire format. It is part of the
+// PublicKey interface.
+func (c *Certificate) Marshal() []byte {
+	generic := genericCertData{
+		Serial:          c.Serial,
+		CertType:        c.CertType,
+		KeyId:           c.KeyId,
+		ValidPrincipals: marshalStringList(c.ValidPrincipals),
+		ValidAfter:      uint64(c.ValidAfter),
+		ValidBefore:     uint64(c.ValidBefore),
+		CriticalOptions: marshalTuples(c.CriticalOptions),
+		Extensions:      marshalTuples(c.Extensions),
+		Reserved:        c.Reserved,
+		SignatureKey:    c.SignatureKey.Marshal(),
+	}
+	if c.Signature != nil {
+		generic.Signature = Marshal(c.Signature)
+	}
+	genericBytes := Marshal(&generic)
+	keyBytes := c.Key.Marshal()
+	_, keyBytes, _ = parseString(keyBytes)
+	prefix := Marshal(&struct {
+		Name  string
+		Nonce []byte
+		Key   []byte `ssh:"rest"`
+	}{c.Type(), c.Nonce, keyBytes})
+
+	result := make([]byte, 0, len(prefix)+len(genericBytes))
+	result = append(result, prefix...)
+	result = append(result, genericBytes...)
+	return result
 }
 
-func (cert *OpenSSHCertV01) Marshal() []byte {
-	return cert.marshal(false, true)
-}
-
-func (c *OpenSSHCertV01) PublicKeyAlgo() string {
-	algo, ok := certAlgoNames[c.Key.PublicKeyAlgo()]
+// Type returns the key name. It is part of the PublicKey interface.
+func (c *Certificate) Type() string {
+	algo, ok := certAlgoNames[c.Key.Type()]
 	if !ok {
 		panic("unknown cert key type")
 	}
 	return algo
 }
 
-func (c *OpenSSHCertV01) PrivateKeyAlgo() string {
-	return c.Key.PrivateKeyAlgo()
-}
-
-func (c *OpenSSHCertV01) Verify(data []byte, sig []byte) bool {
+// Verify verifies a signature against the certificate's public
+// key. It is part of the PublicKey interface.
+func (c *Certificate) Verify(data []byte, sig *Signature) error {
 	return c.Key.Verify(data, sig)
 }
 
-func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []byte, ok bool) {
-	cert := new(OpenSSHCertV01)
-
-	if cert.Nonce, in, ok = parseString(in); !ok {
-		return
-	}
-
-	privAlgo := certToPrivAlgo(algo)
-	cert.Key, in, ok = parsePubKey(in, privAlgo)
+func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) {
+	format, in, ok := parseString(in)
 	if !ok {
 		return
 	}
 
-	// We test PublicKeyAlgo to make sure we don't use some weird sub-cert.
-	if cert.Key.PublicKeyAlgo() != privAlgo {
-		ok = false
-		return
-	}
-
-	if cert.Serial, in, ok = parseUint64(in); !ok {
-		return
-	}
-
-	if cert.Type, in, ok = parseUint32(in); !ok {
-		return
-	}
-
-	keyId, in, ok := parseString(in)
-	if !ok {
-		return
-	}
-	cert.KeyId = string(keyId)
-
-	if cert.ValidPrincipals, in, ok = parseLengthPrefixedNameList(in); !ok {
-		return
-	}
-
-	va, in, ok := parseUint64(in)
-	if !ok {
-		return
-	}
-	cert.ValidAfter = CertTime(va)
-
-	vb, in, ok := parseUint64(in)
-	if !ok {
-		return
-	}
-	cert.ValidBefore = CertTime(vb)
-
-	if cert.CriticalOptions, in, ok = parseTupleList(in); !ok {
-		return
-	}
-
-	if cert.Extensions, in, ok = parseTupleList(in); !ok {
-		return
-	}
-
-	if cert.Reserved, in, ok = parseString(in); !ok {
-		return
-	}
-
-	sigKey, in, ok := parseString(in)
-	if !ok {
-		return
-	}
-	if cert.SignatureKey, _, ok = ParsePublicKey(sigKey); !ok {
-		return
-	}
-
-	if cert.Signature, in, ok = parseSignature(in); !ok {
-		return
-	}
-
-	ok = true
-	return cert, in, ok
-}
-
-func lengthPrefixedNameListLength(namelist []string) int {
-	length := 4 // length prefix for list
-	for _, name := range namelist {
-		length += 4 // length prefix for name
-		length += len(name)
-	}
-	return length
-}
-
-func marshalLengthPrefixedNameList(to []byte, namelist []string) []byte {
-	length := uint32(lengthPrefixedNameListLength(namelist) - 4)
-	to = marshalUint32(to, length)
-	for _, name := range namelist {
-		to = marshalString(to, []byte(name))
-	}
-	return to
-}
-
-func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool) {
-	list, rest, ok := parseString(in)
-	if !ok {
-		return
-	}
-
-	for len(list) > 0 {
-		var next []byte
-		if next, list, ok = parseString(list); !ok {
-			return nil, nil, false
-		}
-		out = append(out, string(next))
-	}
-	ok = true
-	return
-}
-
-func tupleListLength(tupleList []tuple) int {
-	length := 4 // length prefix for list
-	for _, t := range tupleList {
-		length += 4 // length prefix for t.Name
-		length += len(t.Name)
-		length += 4 // length prefix for t.Data
-		length += len(t.Data)
-	}
-	return length
-}
-
-func marshalTupleList(to []byte, tuplelist []tuple) []byte {
-	length := uint32(tupleListLength(tuplelist) - 4)
-	to = marshalUint32(to, length)
-	for _, t := range tuplelist {
-		to = marshalString(to, []byte(t.Name))
-		to = marshalString(to, []byte(t.Data))
-	}
-	return to
-}
-
-func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) {
-	list, rest, ok := parseString(in)
-	if !ok {
-		return
-	}
-
-	for len(list) > 0 {
-		var name, data []byte
-		var ok bool
-		name, list, ok = parseString(list)
-		if !ok {
-			return nil, nil, false
-		}
-		data, list, ok = parseString(list)
-		if !ok {
-			return nil, nil, false
-		}
-		out = append(out, tuple{string(name), string(data)})
-	}
-	ok = true
-	return
-}
-
-func signatureLength(sig *signature) int {
-	length := 4 // length prefix for signature
-	length += stringLength(len(sig.Format))
-	length += stringLength(len(sig.Blob))
-	return length
-}
-
-func marshalSignature(to []byte, sig *signature) []byte {
-	length := uint32(signatureLength(sig) - 4)
-	to = marshalUint32(to, length)
-	to = marshalString(to, []byte(sig.Format))
-	to = marshalString(to, sig.Blob)
-	return to
-}
-
-func parseSignatureBody(in []byte) (out *signature, rest []byte, ok bool) {
-	var format []byte
-	if format, in, ok = parseString(in); !ok {
-		return
-	}
-
-	out = &signature{
+	out = &Signature{
 		Format: string(format),
 	}
 
@@ -364,14 +460,14 @@
 	return out, in, ok
 }
 
-func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
-	var sigBytes []byte
-	if sigBytes, rest, ok = parseString(in); !ok {
+func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) {
+	sigBytes, rest, ok := parseString(in)
+	if !ok {
 		return
 	}
 
-	out, sigBytes, ok = parseSignatureBody(sigBytes)
-	if !ok || len(sigBytes) > 0 {
+	out, trailing, ok := parseSignatureBody(sigBytes)
+	if !ok || len(trailing) > 0 {
 		return nil, nil, false
 	}
 	return
diff --git a/ssh/certs_test.go b/ssh/certs_test.go
index 3cec28e..7d1b00f 100644
--- a/ssh/certs_test.go
+++ b/ssh/certs_test.go
@@ -6,7 +6,9 @@
 
 import (
 	"bytes"
+	"crypto/rand"
 	"testing"
+	"time"
 )
 
 // Cert generated by ssh-keygen 6.0p1 Debian-4.
@@ -16,16 +18,16 @@
 func TestParseCert(t *testing.T) {
 	authKeyBytes := []byte(exampleSSHCert)
 
-	key, _, _, rest, ok := ParseAuthorizedKey(authKeyBytes)
-	if !ok {
-		t.Fatalf("could not parse certificate")
+	key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
+	if err != nil {
+		t.Fatalf("ParseAuthorizedKey: %v", err)
 	}
 	if len(rest) > 0 {
 		t.Errorf("rest: got %q, want empty", rest)
 	}
 
-	if _, ok = key.(*OpenSSHCertV01); !ok {
-		t.Fatalf("got %#v, want *OpenSSHCertV01", key)
+	if _, ok := key.(*Certificate); !ok {
+		t.Fatalf("got %#v, want *Certificate", key)
 	}
 
 	marshaled := MarshalAuthorizedKey(key)
@@ -37,19 +39,118 @@
 	}
 }
 
-func TestVerifyCert(t *testing.T) {
-	key, _, _, _, _ := ParseAuthorizedKey([]byte(exampleSSHCert))
-	validCert := key.(*OpenSSHCertV01)
-	if ok := validateOpenSSHCertV01Signature(validCert); !ok {
-		t.Error("Unable to validate certificate!")
+func TestValidateCert(t *testing.T) {
+	key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
+	if err != nil {
+		t.Fatalf("ParseAuthorizedKey: %v", err)
+	}
+	validCert, ok := key.(*Certificate)
+	if !ok {
+		t.Fatalf("got %v (%T), want *Certificate", key, key)
+	}
+	checker := CertChecker{}
+	checker.IsAuthority = func(k PublicKey) bool {
+		return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
 	}
 
-	invalidCert := &OpenSSHCertV01{
-		Key:          rsaKey.PublicKey(),
-		SignatureKey: ecdsaKey.PublicKey(),
-		Signature:    &signature{},
+	if err := checker.CheckCert("user", validCert); err != nil {
+		t.Errorf("Unable to validate certificate: %v", err)
 	}
-	if ok := validateOpenSSHCertV01Signature(invalidCert); ok {
-		t.Error("Invalid cert signature passed validation!")
+	invalidCert := &Certificate{
+		Key:          testPublicKeys["rsa"],
+		SignatureKey: testPublicKeys["ecdsa"],
+		ValidBefore:  CertTimeInfinity,
+		Signature:    &Signature{},
+	}
+	if err := checker.CheckCert("user", invalidCert); err == nil {
+		t.Error("Invalid cert signature passed validation")
+	}
+}
+
+func TestValidateCertTime(t *testing.T) {
+	cert := Certificate{
+		ValidPrincipals: []string{"user"},
+		Key:             testPublicKeys["rsa"],
+		ValidAfter:      50,
+		ValidBefore:     100,
+	}
+
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+
+	for ts, ok := range map[int64]bool{
+		25:  false,
+		50:  true,
+		99:  true,
+		100: false,
+		125: false,
+	} {
+		checker := CertChecker{
+			Clock: func() time.Time { return time.Unix(ts, 0) },
+		}
+		checker.IsAuthority = func(k PublicKey) bool {
+			return bytes.Equal(k.Marshal(),
+				testPublicKeys["ecdsa"].Marshal())
+		}
+
+		if v := checker.CheckCert("user", &cert); (v == nil) != ok {
+			t.Errorf("Authenticate(%d): %v", ts, v)
+		}
+	}
+}
+
+// TODO(hanwen): tests for
+//
+// host keys:
+// * fallbacks
+
+func TestHostKeyCert(t *testing.T) {
+	cert := &Certificate{
+		ValidPrincipals: []string{"hostname", "hostname.domain"},
+		Key:             testPublicKeys["rsa"],
+		ValidBefore:     CertTimeInfinity,
+		CertType:        HostCert,
+	}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+
+	checker := &CertChecker{
+		IsAuthority: func(p PublicKey) bool {
+			return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
+		},
+	}
+
+	certSigner, err := NewCertSigner(cert, testSigners["rsa"])
+	if err != nil {
+		t.Errorf("NewCertSigner: %v", err)
+	}
+
+	for _, name := range []string{"hostname", "otherhost"} {
+		c1, c2, err := netPipe()
+		if err != nil {
+			t.Fatalf("netPipe: %v", err)
+		}
+		defer c1.Close()
+		defer c2.Close()
+
+		go func() {
+			conf := ServerConfig{
+				NoClientAuth: true,
+			}
+			conf.AddHostKey(certSigner)
+			_, _, _, err := NewServerConn(c1, &conf)
+			if err != nil {
+				t.Fatalf("NewServerConn: %v", err)
+			}
+		}()
+
+		config := &ClientConfig{
+			User:            "user",
+			HostKeyCallback: checker.CheckHostKey,
+		}
+		_, _, _, err = NewClientConn(c2, name, config)
+
+		succeed := name == "hostname"
+		if (err == nil) != succeed {
+			t.Fatalf("NewClientConn(%q): %v", name, err)
+		}
 	}
 }
diff --git a/ssh/channel.go b/ssh/channel.go
index c5413c9..8e777bb 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -5,71 +5,100 @@
 package ssh
 
 import (
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"sync"
-	"sync/atomic"
 )
 
-// extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
-// section 5.2.
-type extendedDataTypeCode uint32
-
 const (
-	// extendedDataStderr is the extended data type that is used for stderr.
-	extendedDataStderr extendedDataTypeCode = 1
-
-	// minPacketLength defines the smallest valid packet
 	minPacketLength = 9
-
-	// channelMaxPacketSize defines the maximum packet size advertised in open messages
-	channelMaxPacketSize = 1 << 15 // RFC 4253 6.1, minimum 32 KiB
-
-	// channelWindowSize defines the window size advertised in open messages
-	channelWindowSize = 64 * channelMaxPacketSize // Like OpenSSH
+	// channelMaxPacket contains the maximum number of bytes that will be
+	// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
+	// the minimum.
+	channelMaxPacket = 1 << 15
+	// We follow OpenSSH here.
+	channelWindowSize = 64 * channelMaxPacket
 )
 
-// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
-// SSH connection. Channel.Read can return a ChannelRequest as an error.
-type Channel interface {
-	// Accept accepts the channel creation request.
-	Accept() error
-	// Reject rejects the channel creation request. After calling this, no
-	// other methods on the Channel may be called. If they are then the
-	// peer is likely to signal a protocol error and drop the connection.
+// NewChannel represents an incoming request to a channel. It must either be
+// accepted for use by calling Accept, or rejected by calling Reject.
+type NewChannel interface {
+	// Accept accepts the channel creation request. It returns the Channel
+	// and a Go channel containing SSH requests. The Go channel must be
+	// serviced otherwise the Channel will hang.
+	Accept() (Channel, <-chan *Request, error)
+
+	// Reject rejects the channel creation request. After calling
+	// this, no other methods on the Channel may be called.
 	Reject(reason RejectionReason, message string) error
 
-	// Read may return a ChannelRequest as an error.
-	Read(data []byte) (int, error)
-	Write(data []byte) (int, error)
-	Close() error
-
-	// Stderr returns an io.Writer that writes to this channel with the
-	// extended data type set to stderr.
-	Stderr() io.Writer
-
-	// AckRequest either sends an ack or nack to the channel request.
-	AckRequest(ok bool) error
-
 	// ChannelType returns the type of the channel, as supplied by the
 	// client.
 	ChannelType() string
+
 	// ExtraData returns the arbitrary payload for this channel, as supplied
 	// by the client. This data is specific to the channel type.
 	ExtraData() []byte
 }
 
-// ChannelRequest represents a request sent on a channel, outside of the normal
-// stream of bytes. It may result from calling Read on a Channel.
-type ChannelRequest struct {
-	Request   string
-	WantReply bool
-	Payload   []byte
+// A Channel is an ordered, reliable, flow-controlled, duplex stream
+// that is multiplexed over an SSH connection.
+type Channel interface {
+	// Read reads up to len(data) bytes from the channel.
+	Read(data []byte) (int, error)
+
+	// Write writes len(data) bytes to the channel.
+	Write(data []byte) (int, error)
+
+	// Close signals end of channel use. No data may be sent after this
+	// call.
+	Close() error
+
+	// CloseWrite signals the end of sending in-band
+	// data. Requests may still be sent, and the other side may
+	// still send data
+	CloseWrite() error
+
+	// SendRequest sends a channel request.  If wantReply is true,
+	// it will wait for a reply and return the result as a
+	// boolean, otherwise the return value will be false. Channel
+	// requests are out-of-band messages so they may be sent even
+	// if the data stream is closed or blocked by flow control.
+	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
+
+	// Stderr returns an io.ReadWriter that writes to this channel with the
+	// extended data type set to stderr.
+	Stderr() io.ReadWriter
 }
 
-func (c ChannelRequest) Error() string {
-	return "ssh: channel request received"
+// Request is a request sent outside of the normal stream of
+// data. Requests can either be specific to an SSH channel, or they
+// can be global.
+type Request struct {
+	Type      string
+	WantReply bool
+	Payload   []byte
+
+	ch  *channel
+	mux *mux
+}
+
+// Reply sends a response to a request. It must be called for all requests
+// where WantReply is true and is a no-op otherwise. The payload argument is
+// ignored for replies to channel-specific requests.
+func (r *Request) Reply(ok bool, payload []byte) error {
+	if !r.WantReply {
+		return nil
+	}
+
+	if r.ch == nil {
+		return r.mux.ackRequest(ok, payload)
+	}
+
+	return r.ch.ackRequest(ok)
 }
 
 // RejectionReason is an enumeration used when rejecting channel creation
@@ -98,464 +127,6 @@
 	return fmt.Sprintf("unknown reason %d", int(r))
 }
 
-type channel struct {
-	packetConn        // the underlying transport
-	localId, remoteId uint32
-	remoteWin         window
-	maxPacket         uint32
-	isClosed          uint32 // atomic bool, non zero if true
-}
-
-func (c *channel) sendWindowAdj(n int) error {
-	msg := windowAdjustMsg{
-		PeersId:         c.remoteId,
-		AdditionalBytes: uint32(n),
-	}
-	return c.writePacket(marshal(msgChannelWindowAdjust, msg))
-}
-
-// sendEOF sends EOF to the remote side. RFC 4254 Section 5.3
-func (c *channel) sendEOF() error {
-	return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
-		PeersId: c.remoteId,
-	}))
-}
-
-// sendClose informs the remote side of our intent to close the channel.
-func (c *channel) sendClose() error {
-	return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{
-		PeersId: c.remoteId,
-	}))
-}
-
-func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
-	reject := channelOpenFailureMsg{
-		PeersId:  c.remoteId,
-		Reason:   reason,
-		Message:  message,
-		Language: "en",
-	}
-	return c.writePacket(marshal(msgChannelOpenFailure, reject))
-}
-
-func (c *channel) writePacket(b []byte) error {
-	if c.closed() {
-		return io.EOF
-	}
-	if uint32(len(b)) > c.maxPacket {
-		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
-	}
-	return c.packetConn.writePacket(b)
-}
-
-func (c *channel) closed() bool {
-	return atomic.LoadUint32(&c.isClosed) > 0
-}
-
-func (c *channel) setClosed() bool {
-	return atomic.CompareAndSwapUint32(&c.isClosed, 0, 1)
-}
-
-type serverChan struct {
-	channel
-	// immutable once created
-	chanType  string
-	extraData []byte
-
-	serverConn  *ServerConn
-	myWindow    uint32
-	theyClosed  bool // indicates the close msg has been received from the remote side
-	theySentEOF bool
-	isDead      uint32
-	err         error
-
-	pendingRequests []ChannelRequest
-	pendingData     []byte
-	head, length    int
-
-	// This lock is inferior to serverConn.lock
-	cond *sync.Cond
-}
-
-func (c *serverChan) Accept() error {
-	c.serverConn.lock.Lock()
-	defer c.serverConn.lock.Unlock()
-
-	if c.serverConn.err != nil {
-		return c.serverConn.err
-	}
-
-	confirm := channelOpenConfirmMsg{
-		PeersId:       c.remoteId,
-		MyId:          c.localId,
-		MyWindow:      c.myWindow,
-		MaxPacketSize: c.maxPacket,
-	}
-	return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
-}
-
-func (c *serverChan) Reject(reason RejectionReason, message string) error {
-	c.serverConn.lock.Lock()
-	defer c.serverConn.lock.Unlock()
-
-	if c.serverConn.err != nil {
-		return c.serverConn.err
-	}
-
-	return c.sendChannelOpenFailure(reason, message)
-}
-
-func (c *serverChan) handlePacket(packet interface{}) {
-	c.cond.L.Lock()
-	defer c.cond.L.Unlock()
-
-	switch packet := packet.(type) {
-	case *channelRequestMsg:
-		req := ChannelRequest{
-			Request:   packet.Request,
-			WantReply: packet.WantReply,
-			Payload:   packet.RequestSpecificData,
-		}
-
-		c.pendingRequests = append(c.pendingRequests, req)
-		c.cond.Signal()
-	case *channelCloseMsg:
-		c.theyClosed = true
-		c.cond.Signal()
-	case *channelEOFMsg:
-		c.theySentEOF = true
-		c.cond.Signal()
-	case *windowAdjustMsg:
-		if !c.remoteWin.add(packet.AdditionalBytes) {
-			panic("illegal window update")
-		}
-	default:
-		panic("unknown packet type")
-	}
-}
-
-func (c *serverChan) handleData(data []byte) {
-	c.cond.L.Lock()
-	defer c.cond.L.Unlock()
-
-	// The other side should never send us more than our window.
-	if len(data)+c.length > len(c.pendingData) {
-		// TODO(agl): we should tear down the channel with a protocol
-		// error.
-		return
-	}
-
-	c.myWindow -= uint32(len(data))
-	for i := 0; i < 2; i++ {
-		tail := c.head + c.length
-		if tail >= len(c.pendingData) {
-			tail -= len(c.pendingData)
-		}
-		n := copy(c.pendingData[tail:], data)
-		data = data[n:]
-		c.length += n
-	}
-
-	c.cond.Signal()
-}
-
-func (c *serverChan) Stderr() io.Writer {
-	return extendedDataChannel{c: c, t: extendedDataStderr}
-}
-
-// extendedDataChannel is an io.Writer that writes any data to c as extended
-// data of the given type.
-type extendedDataChannel struct {
-	t extendedDataTypeCode
-	c *serverChan
-}
-
-func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
-	const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
-	c := edc.c
-	for len(data) > 0 {
-		space := min(c.maxPacket-headerLength, len(data))
-		if space, err = c.getWindowSpace(space); err != nil {
-			return 0, err
-		}
-		todo := data
-		if uint32(len(todo)) > space {
-			todo = todo[:space]
-		}
-
-		packet := make([]byte, headerLength+len(todo))
-		packet[0] = msgChannelExtendedData
-		marshalUint32(packet[1:], c.remoteId)
-		marshalUint32(packet[5:], uint32(edc.t))
-		marshalUint32(packet[9:], uint32(len(todo)))
-		copy(packet[13:], todo)
-
-		if err = c.writePacket(packet); err != nil {
-			return
-		}
-
-		n += len(todo)
-		data = data[len(todo):]
-	}
-
-	return
-}
-
-func (c *serverChan) Read(data []byte) (n int, err error) {
-	n, err, windowAdjustment := c.read(data)
-
-	if windowAdjustment > 0 {
-		packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
-			PeersId:         c.remoteId,
-			AdditionalBytes: windowAdjustment,
-		})
-		err = c.writePacket(packet)
-	}
-
-	return
-}
-
-func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
-	c.cond.L.Lock()
-	defer c.cond.L.Unlock()
-
-	if c.err != nil {
-		return 0, c.err, 0
-	}
-
-	for {
-		if c.theySentEOF || c.theyClosed || c.dead() {
-			return 0, io.EOF, 0
-		}
-
-		if len(c.pendingRequests) > 0 {
-			req := c.pendingRequests[0]
-			if len(c.pendingRequests) == 1 {
-				c.pendingRequests = nil
-			} else {
-				oldPendingRequests := c.pendingRequests
-				c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
-				copy(c.pendingRequests, oldPendingRequests[1:])
-			}
-
-			return 0, req, 0
-		}
-
-		if c.length > 0 {
-			tail := min(uint32(c.head+c.length), len(c.pendingData))
-			n = copy(data, c.pendingData[c.head:tail])
-			c.head += n
-			c.length -= n
-			if c.head == len(c.pendingData) {
-				c.head = 0
-			}
-
-			windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
-			if windowAdjustment < uint32(len(c.pendingData)/2) {
-				windowAdjustment = 0
-			}
-			c.myWindow += windowAdjustment
-
-			return
-		}
-
-		c.cond.Wait()
-	}
-
-	panic("unreachable")
-}
-
-// getWindowSpace takes, at most, max bytes of space from the peer's window. It
-// returns the number of bytes actually reserved.
-func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
-	if c.dead() || c.closed() {
-		return 0, io.EOF
-	}
-	return c.remoteWin.reserve(max), nil
-}
-
-func (c *serverChan) dead() bool {
-	return atomic.LoadUint32(&c.isDead) > 0
-}
-
-func (c *serverChan) setDead() {
-	atomic.StoreUint32(&c.isDead, 1)
-}
-
-func (c *serverChan) Write(data []byte) (n int, err error) {
-	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
-	for len(data) > 0 {
-		space := min(c.maxPacket-headerLength, len(data))
-		if space, err = c.getWindowSpace(space); err != nil {
-			return 0, err
-		}
-		todo := data
-		if uint32(len(todo)) > space {
-			todo = todo[:space]
-		}
-
-		packet := make([]byte, headerLength+len(todo))
-		packet[0] = msgChannelData
-		marshalUint32(packet[1:], c.remoteId)
-		marshalUint32(packet[5:], uint32(len(todo)))
-		copy(packet[9:], todo)
-
-		if err = c.writePacket(packet); err != nil {
-			return
-		}
-
-		n += len(todo)
-		data = data[len(todo):]
-	}
-
-	return
-}
-
-// Close signals the intent to close the channel.
-func (c *serverChan) Close() error {
-	c.serverConn.lock.Lock()
-	defer c.serverConn.lock.Unlock()
-
-	if c.serverConn.err != nil {
-		return c.serverConn.err
-	}
-
-	if !c.setClosed() {
-		return errors.New("ssh: channel already closed")
-	}
-	return c.sendClose()
-}
-
-func (c *serverChan) AckRequest(ok bool) error {
-	c.serverConn.lock.Lock()
-	defer c.serverConn.lock.Unlock()
-
-	if c.serverConn.err != nil {
-		return c.serverConn.err
-	}
-
-	if !ok {
-		ack := channelRequestFailureMsg{
-			PeersId: c.remoteId,
-		}
-		return c.writePacket(marshal(msgChannelFailure, ack))
-	}
-
-	ack := channelRequestSuccessMsg{
-		PeersId: c.remoteId,
-	}
-	return c.writePacket(marshal(msgChannelSuccess, ack))
-}
-
-func (c *serverChan) ChannelType() string {
-	return c.chanType
-}
-
-func (c *serverChan) ExtraData() []byte {
-	return c.extraData
-}
-
-// A clientChan represents a single RFC 4254 channel multiplexed
-// over a SSH connection.
-type clientChan struct {
-	channel
-	stdin  *chanWriter
-	stdout *chanReader
-	stderr *chanReader
-	msg    chan interface{}
-}
-
-// newClientChan returns a partially constructed *clientChan
-// using the local id provided. To be usable clientChan.remoteId
-// needs to be assigned once known.
-func newClientChan(cc packetConn, id uint32) *clientChan {
-	c := &clientChan{
-		channel: channel{
-			packetConn: cc,
-			localId:    id,
-			remoteWin:  window{Cond: newCond()},
-		},
-		msg: make(chan interface{}, 16),
-	}
-	c.stdin = &chanWriter{
-		channel: &c.channel,
-	}
-	c.stdout = &chanReader{
-		channel: &c.channel,
-		buffer:  newBuffer(),
-	}
-	c.stderr = &chanReader{
-		channel: &c.channel,
-		buffer:  newBuffer(),
-	}
-	return c
-}
-
-// waitForChannelOpenResponse, if successful, fills out
-// the remoteId and records any initial window advertisement.
-func (c *clientChan) waitForChannelOpenResponse() error {
-	switch msg := (<-c.msg).(type) {
-	case *channelOpenConfirmMsg:
-		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
-			return errors.New("ssh: invalid MaxPacketSize from peer")
-		}
-		// fixup remoteId field
-		c.remoteId = msg.MyId
-		c.maxPacket = msg.MaxPacketSize
-		c.remoteWin.add(msg.MyWindow)
-		return nil
-	case *channelOpenFailureMsg:
-		return errors.New(safeString(msg.Message))
-	}
-	return errors.New("ssh: unexpected packet")
-}
-
-// Close signals the intent to close the channel.
-func (c *clientChan) Close() error {
-	if !c.setClosed() {
-		return errors.New("ssh: channel already closed")
-	}
-	c.stdout.eof()
-	c.stderr.eof()
-	return c.sendClose()
-}
-
-// A chanWriter represents the stdin of a remote process.
-type chanWriter struct {
-	*channel
-	// indicates the writer has been closed. eof is owned by the
-	// caller of Write/Close.
-	eof bool
-}
-
-// Write writes data to the remote process's standard input.
-func (w *chanWriter) Write(data []byte) (written int, err error) {
-	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
-	for len(data) > 0 {
-		if w.eof || w.closed() {
-			err = io.EOF
-			return
-		}
-		// never send more data than maxPacket even if
-		// there is sufficient window.
-		n := min(w.maxPacket-headerLength, len(data))
-		r := w.remoteWin.reserve(n)
-		n = r
-		remoteId := w.remoteId
-		packet := []byte{
-			msgChannelData,
-			byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
-			byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
-		}
-		if err = w.writePacket(append(packet, data[:n]...)); err != nil {
-			break
-		}
-		data = data[n:]
-		written += int(n)
-	}
-	return
-}
-
 func min(a uint32, b int) uint32 {
 	if a < uint32(b) {
 		return a
@@ -563,32 +134,475 @@
 	return uint32(b)
 }
 
-func (w *chanWriter) Close() error {
-	w.eof = true
-	return w.sendEOF()
+type channelDirection uint8
+
+const (
+	channelInbound channelDirection = iota
+	channelOutbound
+)
+
+// channel is an implementation of the Channel interface that works
+// with the mux class.
+type channel struct {
+	// R/O after creation
+	chanType          string
+	extraData         []byte
+	localId, remoteId uint32
+
+	// maxIncomingPayload and maxRemotePayload are the maximum
+	// payload sizes of normal and extended data packets for
+	// receiving and sending, respectively. The wire packet will
+	// be 9 or 13 bytes larger (excluding encryption overhead).
+	maxIncomingPayload uint32
+	maxRemotePayload   uint32
+
+	mux *mux
+
+	// decided is set to true if an accept or reject message has been sent
+	// (for outbound channels) or received (for inbound channels).
+	decided bool
+
+	// direction contains either channelOutbound, for channels created
+	// locally, or channelInbound, for channels created by the peer.
+	direction channelDirection
+
+	// Pending internal channel messages.
+	msg chan interface{}
+
+	// Since requests have no ID, there can be only one request
+	// with WantReply=true outstanding.  This lock is held by a
+	// goroutine that has such an outgoing request pending.
+	sentRequestMu sync.Mutex
+
+	incomingRequests chan *Request
+
+	sentEOF bool
+
+	// thread-safe data
+	remoteWin  window
+	pending    *buffer
+	extPending *buffer
+
+	// windowMu protects myWindow, the flow-control window.
+	windowMu sync.Mutex
+	myWindow uint32
+
+	// writeMu serializes calls to mux.conn.writePacket() and
+	// protects sentClose. This mutex must be different from
+	// windowMu, as writePacket can block if there is a key
+	// exchange pending
+	writeMu   sync.Mutex
+	sentClose bool
 }
 
-// A chanReader represents stdout or stderr of a remote process.
-type chanReader struct {
-	*channel // the channel backing this reader
-	*buffer
+// writePacket sends a packet. If the packet is a channel close, it updates
+// sentClose. This method takes the lock c.writeMu.
+func (c *channel) writePacket(packet []byte) error {
+	c.writeMu.Lock()
+	if c.sentClose {
+		c.writeMu.Unlock()
+		return io.EOF
+	}
+	c.sentClose = (packet[0] == msgChannelClose)
+	err := c.mux.conn.writePacket(packet)
+	c.writeMu.Unlock()
+	return err
 }
 
-// Read reads data from the remote process's stdout or stderr.
-func (r *chanReader) Read(buf []byte) (int, error) {
-	n, err := r.buffer.Read(buf)
-	if err != nil {
-		if err == io.EOF {
+func (c *channel) sendMessage(msg interface{}) error {
+	if debugMux {
+		log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
+	}
+
+	p := Marshal(msg)
+	binary.BigEndian.PutUint32(p[1:], c.remoteId)
+	return c.writePacket(p)
+}
+
+// WriteExtended writes data to a specific extended stream. These streams are
+// used, for example, for stderr.
+func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
+	if c.sentEOF {
+		return 0, io.EOF
+	}
+	// 1 byte message type, 4 bytes remoteId, 4 bytes data length
+	opCode := byte(msgChannelData)
+	headerLength := uint32(9)
+	if extendedCode > 0 {
+		headerLength += 4
+		opCode = msgChannelExtendedData
+	}
+
+	for len(data) > 0 {
+		space := min(c.maxRemotePayload, len(data))
+		if space, err = c.remoteWin.reserve(space); err != nil {
 			return n, err
 		}
-		return 0, err
+		todo := data[:space]
+
+		packet := make([]byte, headerLength+uint32(len(todo)))
+		packet[0] = opCode
+		binary.BigEndian.PutUint32(packet[1:], c.remoteId)
+		if extendedCode > 0 {
+			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
+		}
+		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
+		copy(packet[headerLength:], todo)
+		if err = c.writePacket(packet); err != nil {
+			return n, err
+		}
+
+		n += len(todo)
+		data = data[len(todo):]
 	}
-	err = r.sendWindowAdj(n)
-	if err == io.EOF && n > 0 {
-		// sendWindowAdjust can return io.EOF if the remote peer has
-		// closed the connection, however we want to defer forwarding io.EOF to the
-		// caller of Read until the buffer has been drained.
-		err = nil
-	}
+
 	return n, err
 }
+
+func (c *channel) handleData(packet []byte) error {
+	headerLen := 9
+	isExtendedData := packet[0] == msgChannelExtendedData
+	if isExtendedData {
+		headerLen = 13
+	}
+	if len(packet) < headerLen {
+		// malformed data packet
+		return parseError(packet[0])
+	}
+
+	var extended uint32
+	if isExtendedData {
+		extended = binary.BigEndian.Uint32(packet[5:])
+	}
+
+	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
+	if length == 0 {
+		return nil
+	}
+	if length > c.maxIncomingPayload {
+		// TODO(hanwen): should send Disconnect?
+		return errors.New("ssh: incoming packet exceeds maximum payload size")
+	}
+
+	data := packet[headerLen:]
+	if length != uint32(len(data)) {
+		return errors.New("ssh: wrong packet length")
+	}
+
+	c.windowMu.Lock()
+	if c.myWindow < length {
+		c.windowMu.Unlock()
+		// TODO(hanwen): should send Disconnect with reason?
+		return errors.New("ssh: remote side wrote too much")
+	}
+	c.myWindow -= length
+	c.windowMu.Unlock()
+
+	if extended == 1 {
+		c.extPending.write(data)
+	} else if extended > 0 {
+		// discard other extended data.
+	} else {
+		c.pending.write(data)
+	}
+	return nil
+}
+
+func (c *channel) adjustWindow(n uint32) error {
+	c.windowMu.Lock()
+	// Since myWindow is managed on our side, and can never exceed
+	// the initial window setting, we don't worry about overflow.
+	c.myWindow += uint32(n)
+	c.windowMu.Unlock()
+	return c.sendMessage(windowAdjustMsg{
+		AdditionalBytes: uint32(n),
+	})
+}
+
+func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
+	switch extended {
+	case 1:
+		n, err = c.extPending.Read(data)
+	case 0:
+		n, err = c.pending.Read(data)
+	default:
+		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
+	}
+
+	if n > 0 {
+		err = c.adjustWindow(uint32(n))
+		// sendWindowAdjust can return io.EOF if the remote
+		// peer has closed the connection, however we want to
+		// defer forwarding io.EOF to the caller of Read until
+		// the buffer has been drained.
+		if n > 0 && err == io.EOF {
+			err = nil
+		}
+	}
+
+	return n, err
+}
+
+func (c *channel) close() {
+	c.pending.eof()
+	c.extPending.eof()
+	close(c.msg)
+	close(c.incomingRequests)
+	c.writeMu.Lock()
+	// This is not necesary for a normal channel teardown, but if
+	// there was another error, it is.
+	c.sentClose = true
+	c.writeMu.Unlock()
+	// Unblock writers.
+	c.remoteWin.close()
+}
+
+// responseMessageReceived is called when a success or failure message is
+// received on a channel to check that such a message is reasonable for the
+// given channel.
+func (c *channel) responseMessageReceived() error {
+	if c.direction == channelInbound {
+		return errors.New("ssh: channel response message received on inbound channel")
+	}
+	if c.decided {
+		return errors.New("ssh: duplicate response received for channel")
+	}
+	c.decided = true
+	return nil
+}
+
+func (c *channel) handlePacket(packet []byte) error {
+	switch packet[0] {
+	case msgChannelData, msgChannelExtendedData:
+		return c.handleData(packet)
+	case msgChannelClose:
+		c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
+		c.mux.chanList.remove(c.localId)
+		c.close()
+		return nil
+	case msgChannelEOF:
+		// RFC 4254 is mute on how EOF affects dataExt messages but
+		// it is logical to signal EOF at the same time.
+		c.extPending.eof()
+		c.pending.eof()
+		return nil
+	}
+
+	decoded, err := decode(packet)
+	if err != nil {
+		return err
+	}
+
+	switch msg := decoded.(type) {
+	case *channelOpenFailureMsg:
+		if err := c.responseMessageReceived(); err != nil {
+			return err
+		}
+		c.mux.chanList.remove(msg.PeersId)
+		c.msg <- msg
+	case *channelOpenConfirmMsg:
+		if err := c.responseMessageReceived(); err != nil {
+			return err
+		}
+		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+			return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
+		}
+		c.remoteId = msg.MyId
+		c.maxRemotePayload = msg.MaxPacketSize
+		c.remoteWin.add(msg.MyWindow)
+		c.msg <- msg
+	case *windowAdjustMsg:
+		if !c.remoteWin.add(msg.AdditionalBytes) {
+			return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
+		}
+	case *channelRequestMsg:
+		req := Request{
+			Type:      msg.Request,
+			WantReply: msg.WantReply,
+			Payload:   msg.RequestSpecificData,
+			ch:        c,
+		}
+
+		c.incomingRequests <- &req
+	default:
+		c.msg <- msg
+	}
+	return nil
+}
+
+func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
+	ch := &channel{
+		remoteWin:        window{Cond: newCond()},
+		myWindow:         channelWindowSize,
+		pending:          newBuffer(),
+		extPending:       newBuffer(),
+		direction:        direction,
+		incomingRequests: make(chan *Request, 16),
+		msg:              make(chan interface{}, 16),
+		chanType:         chanType,
+		extraData:        extraData,
+		mux:              m,
+	}
+	ch.localId = m.chanList.add(ch)
+	return ch
+}
+
+var errUndecided = errors.New("ssh: must Accept or Reject channel")
+var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
+
+type extChannel struct {
+	code uint32
+	ch   *channel
+}
+
+func (e *extChannel) Write(data []byte) (n int, err error) {
+	return e.ch.WriteExtended(data, e.code)
+}
+
+func (e *extChannel) Read(data []byte) (n int, err error) {
+	return e.ch.ReadExtended(data, e.code)
+}
+
+func (c *channel) Accept() (Channel, <-chan *Request, error) {
+	if c.decided {
+		return nil, nil, errDecidedAlready
+	}
+	c.maxIncomingPayload = channelMaxPacket
+	confirm := channelOpenConfirmMsg{
+		PeersId:       c.remoteId,
+		MyId:          c.localId,
+		MyWindow:      c.myWindow,
+		MaxPacketSize: c.maxIncomingPayload,
+	}
+	c.decided = true
+	if err := c.sendMessage(confirm); err != nil {
+		return nil, nil, err
+	}
+
+	return c, c.incomingRequests, nil
+}
+
+func (ch *channel) Reject(reason RejectionReason, message string) error {
+	if ch.decided {
+		return errDecidedAlready
+	}
+	reject := channelOpenFailureMsg{
+		PeersId:  ch.remoteId,
+		Reason:   reason,
+		Message:  message,
+		Language: "en",
+	}
+	ch.decided = true
+	return ch.sendMessage(reject)
+}
+
+func (ch *channel) Read(data []byte) (int, error) {
+	if !ch.decided {
+		return 0, errUndecided
+	}
+	return ch.ReadExtended(data, 0)
+}
+
+func (ch *channel) Write(data []byte) (int, error) {
+	if !ch.decided {
+		return 0, errUndecided
+	}
+	return ch.WriteExtended(data, 0)
+}
+
+func (ch *channel) CloseWrite() error {
+	if !ch.decided {
+		return errUndecided
+	}
+	ch.sentEOF = true
+	return ch.sendMessage(channelEOFMsg{
+		PeersId: ch.remoteId})
+}
+
+func (ch *channel) Close() error {
+	if !ch.decided {
+		return errUndecided
+	}
+
+	return ch.sendMessage(channelCloseMsg{
+		PeersId: ch.remoteId})
+}
+
+// Extended returns an io.ReadWriter that sends and receives data on the given,
+// SSH extended stream. Such streams are used, for example, for stderr.
+func (ch *channel) Extended(code uint32) io.ReadWriter {
+	if !ch.decided {
+		return nil
+	}
+	return &extChannel{code, ch}
+}
+
+func (ch *channel) Stderr() io.ReadWriter {
+	return ch.Extended(1)
+}
+
+func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+	if !ch.decided {
+		return false, errUndecided
+	}
+
+	if wantReply {
+		ch.sentRequestMu.Lock()
+		defer ch.sentRequestMu.Unlock()
+	}
+
+	msg := channelRequestMsg{
+		PeersId:             ch.remoteId,
+		Request:             name,
+		WantReply:           wantReply,
+		RequestSpecificData: payload,
+	}
+
+	if err := ch.sendMessage(msg); err != nil {
+		return false, err
+	}
+
+	if wantReply {
+		m, ok := (<-ch.msg)
+		if !ok {
+			return false, io.EOF
+		}
+		switch m.(type) {
+		case *channelRequestFailureMsg:
+			return false, nil
+		case *channelRequestSuccessMsg:
+			return true, nil
+		default:
+			return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
+		}
+	}
+
+	return false, nil
+}
+
+// ackRequest either sends an ack or nack to the channel request.
+func (ch *channel) ackRequest(ok bool) error {
+	if !ch.decided {
+		return errUndecided
+	}
+
+	var msg interface{}
+	if !ok {
+		msg = channelRequestFailureMsg{
+			PeersId: ch.remoteId,
+		}
+	} else {
+		msg = channelRequestSuccessMsg{
+			PeersId: ch.remoteId,
+		}
+	}
+	return ch.sendMessage(msg)
+}
+
+func (ch *channel) ChannelType() string {
+	return ch.chanType
+}
+
+func (ch *channel) ExtraData() []byte {
+	return ch.extraData
+}
diff --git a/ssh/cipher.go b/ssh/cipher.go
index bc2e983..a58f10b 100644
--- a/ssh/cipher.go
+++ b/ssh/cipher.go
@@ -8,11 +8,28 @@
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/rc4"
+	"crypto/subtle"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"hash"
+	"io"
 )
 
-// streamDump is used to dump the initial keystream for stream ciphers. It is a
-// a write-only buffer, and not intended for reading so do not require a mutex.
-var streamDump [512]byte
+const (
+	packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
+
+	// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
+	// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
+	// indicates implementations SHOULD be able to handle larger packet sizes, but then
+	// waffles on about reasonable limits.
+	//
+	// OpenSSH caps their maxPacket at 256kB so we choose to do
+	// the same. maxPacket is also used to ensure that uint32
+	// length fields do not overflow, so it should remain well
+	// below 4G.
+	maxPacket = 256 * 1024
+)
 
 // noneCipher implements cipher.Stream and provides no encryption. It is used
 // by the transport before the first key-exchange.
@@ -34,14 +51,14 @@
 	return rc4.NewCipher(key)
 }
 
-type cipherMode struct {
+type streamCipherMode struct {
 	keySize    int
 	ivSize     int
 	skip       int
 	createFunc func(key, iv []byte) (cipher.Stream, error)
 }
 
-func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
+func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
 	if len(key) < c.keySize {
 		panic("ssh: key length too small for cipher")
 	}
@@ -54,6 +71,11 @@
 		return nil, err
 	}
 
+	var streamDump []byte
+	if c.skip > 0 {
+		streamDump = make([]byte, 512)
+	}
+
 	for remainingToDump := c.skip; remainingToDump > 0; {
 		dumpThisTime := remainingToDump
 		if dumpThisTime > len(streamDump) {
@@ -66,18 +88,10 @@
 	return stream, nil
 }
 
-// Specifies a default set of ciphers and a preference order. This is based on
-// OpenSSH's default client preference order, minus algorithms that are not
-// implemented.
-var DefaultCipherOrder = []string{
-	"aes128-ctr", "aes192-ctr", "aes256-ctr",
-	"arcfour256", "arcfour128",
-}
-
 // cipherModes documents properties of supported ciphers. Ciphers not included
 // are not supported and will not be negotiated, even if explicitly requested in
 // ClientConfig.Crypto.Ciphers.
-var cipherModes = map[string]*cipherMode{
+var cipherModes = map[string]*streamCipherMode{
 	// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
 	// are defined in the order specified in the RFC.
 	"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
@@ -88,13 +102,237 @@
 	// They are defined in the order specified in the RFC.
 	"arcfour128": {16, 0, 1536, newRC4},
 	"arcfour256": {32, 0, 1536, newRC4},
+
+	// AES-GCM is not a stream cipher, so it is constructed with a
+	// special case. If we add any more non-stream ciphers, we
+	// should invest a cleaner way to do this.
+	gcmCipherID: {16, 12, 0, nil},
 }
 
-// defaultKeyExchangeOrder specifies a default set of key exchange algorithms
-// with preferences.
-var defaultKeyExchangeOrder = []string{
-	// P384 and P521 are not constant-time yet, but since we don't
-	// reuse ephemeral keys, using them for ECDH should be OK.
-	kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
-	kexAlgoDH14SHA1, kexAlgoDH1SHA1,
+// prefixLen is the length of the packet prefix that contains the packet length
+// and number of padding bytes.
+const prefixLen = 5
+
+// streamPacketCipher is a packetCipher using a stream cipher.
+type streamPacketCipher struct {
+	mac    hash.Hash
+	cipher cipher.Stream
+
+	// The following members are to avoid per-packet allocations.
+	prefix      [prefixLen]byte
+	seqNumBytes [4]byte
+	padding     [2 * packetSizeMultiple]byte
+	packetData  []byte
+	macResult   []byte
+}
+
+// readPacket reads and decrypt a single packet from the reader argument.
+func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
+	if _, err := io.ReadFull(r, s.prefix[:]); err != nil {
+		return nil, err
+	}
+
+	s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
+	length := binary.BigEndian.Uint32(s.prefix[0:4])
+	paddingLength := uint32(s.prefix[4])
+
+	var macSize uint32
+	if s.mac != nil {
+		s.mac.Reset()
+		binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
+		s.mac.Write(s.seqNumBytes[:])
+		s.mac.Write(s.prefix[:])
+		macSize = uint32(s.mac.Size())
+	}
+
+	if length <= paddingLength+1 {
+		return nil, errors.New("ssh: invalid packet length, packet too small")
+	}
+
+	if length > maxPacket {
+		return nil, errors.New("ssh: invalid packet length, packet too large")
+	}
+
+	// the maxPacket check above ensures that length-1+macSize
+	// does not overflow.
+	if uint32(cap(s.packetData)) < length-1+macSize {
+		s.packetData = make([]byte, length-1+macSize)
+	} else {
+		s.packetData = s.packetData[:length-1+macSize]
+	}
+
+	if _, err := io.ReadFull(r, s.packetData); err != nil {
+		return nil, err
+	}
+	mac := s.packetData[length-1:]
+	data := s.packetData[:length-1]
+	s.cipher.XORKeyStream(data, data)
+
+	if s.mac != nil {
+		s.mac.Write(data)
+		s.macResult = s.mac.Sum(s.macResult[:0])
+		if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
+			return nil, errors.New("ssh: MAC failure")
+		}
+	}
+
+	return s.packetData[:length-paddingLength-1], nil
+}
+
+// writePacket encrypts and sends a packet of data to the writer argument
+func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
+	if len(packet) > maxPacket {
+		return errors.New("ssh: packet too large")
+	}
+
+	paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple
+	if paddingLength < 4 {
+		paddingLength += packetSizeMultiple
+	}
+
+	length := len(packet) + 1 + paddingLength
+	binary.BigEndian.PutUint32(s.prefix[:], uint32(length))
+	s.prefix[4] = byte(paddingLength)
+	padding := s.padding[:paddingLength]
+	if _, err := io.ReadFull(rand, padding); err != nil {
+		return err
+	}
+
+	if s.mac != nil {
+		s.mac.Reset()
+		binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
+		s.mac.Write(s.seqNumBytes[:])
+		s.mac.Write(s.prefix[:])
+		s.mac.Write(packet)
+		s.mac.Write(padding)
+	}
+
+	s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
+	s.cipher.XORKeyStream(packet, packet)
+	s.cipher.XORKeyStream(padding, padding)
+
+	if _, err := w.Write(s.prefix[:]); err != nil {
+		return err
+	}
+	if _, err := w.Write(packet); err != nil {
+		return err
+	}
+	if _, err := w.Write(padding); err != nil {
+		return err
+	}
+
+	if s.mac != nil {
+		s.macResult = s.mac.Sum(s.macResult[:0])
+		if _, err := w.Write(s.macResult); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+type gcmCipher struct {
+	aead   cipher.AEAD
+	prefix [4]byte
+	iv     []byte
+	buf    []byte
+}
+
+func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) {
+	c, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	aead, err := cipher.NewGCM(c)
+	if err != nil {
+		return nil, err
+	}
+
+	return &gcmCipher{
+		aead: aead,
+		iv:   iv,
+	}, nil
+}
+
+const gcmTagSize = 16
+
+func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
+	// Pad out to multiple of 16 bytes. This is different from the
+	// stream cipher because that encrypts the length too.
+	padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple)
+	if padding < 4 {
+		padding += packetSizeMultiple
+	}
+
+	length := uint32(len(packet) + int(padding) + 1)
+	binary.BigEndian.PutUint32(c.prefix[:], length)
+	if _, err := w.Write(c.prefix[:]); err != nil {
+		return err
+	}
+
+	if cap(c.buf) < int(length) {
+		c.buf = make([]byte, length)
+	} else {
+		c.buf = c.buf[:length]
+	}
+
+	c.buf[0] = padding
+	copy(c.buf[1:], packet)
+	if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil {
+		return err
+	}
+	c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:])
+	if _, err := w.Write(c.buf); err != nil {
+		return err
+	}
+	c.incIV()
+
+	return nil
+}
+
+func (c *gcmCipher) incIV() {
+	for i := 4 + 7; i >= 4; i-- {
+		c.iv[i]++
+		if c.iv[i] != 0 {
+			break
+		}
+	}
+}
+
+func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
+	if _, err := io.ReadFull(r, c.prefix[:]); err != nil {
+		return nil, err
+	}
+	length := binary.BigEndian.Uint32(c.prefix[:])
+	if length > maxPacket {
+		return nil, errors.New("ssh: max packet length exceeded.")
+	}
+
+	if cap(c.buf) < int(length+gcmTagSize) {
+		c.buf = make([]byte, length+gcmTagSize)
+	} else {
+		c.buf = c.buf[:length+gcmTagSize]
+	}
+
+	if _, err := io.ReadFull(r, c.buf); err != nil {
+		return nil, err
+	}
+
+	plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:])
+	if err != nil {
+		return nil, err
+	}
+	c.incIV()
+
+	padding := plain[0]
+	if padding < 4 || padding >= 20 {
+		return nil, fmt.Errorf("ssh: illegal padding %d", padding)
+	}
+
+	if int(padding+1) >= len(plain) {
+		return nil, fmt.Errorf("ssh: padding %d too large", padding)
+	}
+	plain = plain[1 : length-uint32(padding)]
+	return plain, nil
 }
diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go
index ea27bd8..e279af0 100644
--- a/ssh/cipher_test.go
+++ b/ssh/cipher_test.go
@@ -6,57 +6,54 @@
 
 import (
 	"bytes"
+	"crypto"
+	"crypto/rand"
 	"testing"
 )
 
-// TestCipherReversal tests that each cipher factory produces ciphers that can
-// encrypt and decrypt some data successfully.
-func TestCipherReversal(t *testing.T) {
-	testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
-	testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
-	testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
-
-	cryptBuffer := make([]byte, 32)
-
-	for name, cipherMode := range cipherModes {
-		encrypter, err := cipherMode.createCipher(testKey, testIv)
-		if err != nil {
-			t.Errorf("failed to create encrypter for %q: %s", name, err)
-			continue
-		}
-		decrypter, err := cipherMode.createCipher(testKey, testIv)
-		if err != nil {
-			t.Errorf("failed to create decrypter for %q: %s", name, err)
-			continue
-		}
-
-		copy(cryptBuffer, testData)
-
-		encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
-		if name == "none" {
-			if !bytes.Equal(cryptBuffer, testData) {
-				t.Errorf("encryption made change with 'none' cipher")
-				continue
-			}
-		} else {
-			if bytes.Equal(cryptBuffer, testData) {
-				t.Errorf("encryption made no change with %q", name)
-				continue
-			}
-		}
-
-		decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
-		if !bytes.Equal(cryptBuffer, testData) {
-			t.Errorf("decrypted bytes not equal to input with %q", name)
-			continue
+func TestDefaultCiphersExist(t *testing.T) {
+	for _, cipherAlgo := range supportedCiphers {
+		if _, ok := cipherModes[cipherAlgo]; !ok {
+			t.Errorf("default cipher %q is unknown", cipherAlgo)
 		}
 	}
 }
 
-func TestDefaultCiphersExist(t *testing.T) {
-	for _, cipherAlgo := range DefaultCipherOrder {
-		if _, ok := cipherModes[cipherAlgo]; !ok {
-			t.Errorf("default cipher %q is unknown", cipherAlgo)
+func TestPacketCiphers(t *testing.T) {
+	for cipher := range cipherModes {
+		kr := &kexResult{Hash: crypto.SHA1}
+		algs := directionAlgorithms{
+			Cipher:      cipher,
+			MAC:         "hmac-sha1",
+			Compression: "none",
+		}
+		client, err := newPacketCipher(clientKeys, algs, kr)
+		if err != nil {
+			t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
+			continue
+		}
+		server, err := newPacketCipher(clientKeys, algs, kr)
+		if err != nil {
+			t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
+			continue
+		}
+
+		want := "bla bla"
+		input := []byte(want)
+		buf := &bytes.Buffer{}
+		if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
+			t.Errorf("writePacket(%q): %v", cipher, err)
+			continue
+		}
+
+		packet, err := server.readPacket(0, buf)
+		if err != nil {
+			t.Errorf("readPacket(%q): %v", cipher, err)
+			continue
+		}
+
+		if string(packet) != want {
+			t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want)
 		}
 	}
 }
diff --git a/ssh/client.go b/ssh/client.go
index e2d2557..a8d5235 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -5,403 +5,158 @@
 package ssh
 
 import (
-	"crypto/rand"
-	"encoding/binary"
 	"errors"
 	"fmt"
-	"io"
 	"net"
 	"sync"
 )
 
-// ClientConn represents the client side of an SSH connection.
-type ClientConn struct {
-	transport   *transport
-	config      *ClientConfig
-	chanList    // channels associated with this connection
-	forwardList // forwarded tcpip connections from the remote side
-	globalRequest
+// Client implements a traditional SSH client that supports shells,
+// subprocesses, port forwarding and tunneled dialing.
+type Client struct {
+	Conn
 
-	// Address as passed to the Dial function.
-	dialAddress string
-
-	serverVersion string
+	forwards        forwardList // forwarded tcpip connections from the remote side
+	mu              sync.Mutex
+	channelHandlers map[string]chan NewChannel
 }
 
-type globalRequest struct {
-	sync.Mutex
-	response chan interface{}
+// HandleChannelOpen returns a channel on which NewChannel requests
+// for the given type are sent. If the type already is being handled,
+// nil is returned. The channel is closed when the connection is closed.
+func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.channelHandlers == nil {
+		// The SSH channel has been closed.
+		c := make(chan NewChannel)
+		close(c)
+		return c
+	}
+
+	ch := c.channelHandlers[channelType]
+	if ch != nil {
+		return nil
+	}
+
+	ch = make(chan NewChannel, 16)
+	c.channelHandlers[channelType] = ch
+	return ch
 }
 
-// Client returns a new SSH client connection using c as the underlying transport.
-func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
-	return clientWithAddress(c, "", config)
+// NewClient creates a Client on top of the given connection.
+func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
+	conn := &Client{
+		Conn:            c,
+		channelHandlers: make(map[string]chan NewChannel, 1),
+	}
+
+	go conn.handleGlobalRequests(reqs)
+	go conn.handleChannelOpens(chans)
+	go func() {
+		conn.Wait()
+		conn.forwards.closeAll()
+	}()
+	go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
+	return conn
 }
 
-func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
-	conn := &ClientConn{
-		transport:     newTransport(c, config.rand(), true /* is client */),
-		config:        config,
-		globalRequest: globalRequest{response: make(chan interface{}, 1)},
-		dialAddress:   addr,
+// NewClientConn establishes an authenticated SSH connection using c
+// as the underlying transport.  The Request and NewChannel channels
+// must be serviced or the connection will hang.
+func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
+	fullConf := *config
+	fullConf.SetDefaults()
+	conn := &connection{
+		sshConn: sshConn{conn: c},
 	}
 
-	if err := conn.handshake(); err != nil {
-		conn.transport.Close()
-		return nil, fmt.Errorf("handshake failed: %v", err)
+	if err := conn.clientHandshake(addr, &fullConf); err != nil {
+		c.Close()
+		return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
 	}
-	go conn.mainLoop()
-	return conn, nil
+	conn.mux = newMux(conn.transport)
+	return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
 }
 
-// Close closes the connection.
-func (c *ClientConn) Close() error { return c.transport.Close() }
-
-// LocalAddr returns the local network address.
-func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
-
-// RemoteAddr returns the remote network address.
-func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
-
-// handshake performs the client side key exchange. See RFC 4253 Section 7.
-func (c *ClientConn) handshake() error {
-	clientVersion := []byte(packageVersion)
-	if c.config.ClientVersion != "" {
-		clientVersion = []byte(c.config.ClientVersion)
+// clientHandshake performs the client side key exchange. See RFC 4253 Section
+// 7.
+func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
+	c.clientVersion = []byte(packageVersion)
+	if config.ClientVersion != "" {
+		c.clientVersion = []byte(config.ClientVersion)
 	}
 
-	serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion)
-	if err != nil {
-		return err
-	}
-	c.serverVersion = string(serverVersion)
-
-	clientKexInit := kexInitMsg{
-		KexAlgos:                c.config.Crypto.kexes(),
-		ServerHostKeyAlgos:      supportedHostKeyAlgos,
-		CiphersClientServer:     c.config.Crypto.ciphers(),
-		CiphersServerClient:     c.config.Crypto.ciphers(),
-		MACsClientServer:        c.config.Crypto.macs(),
-		MACsServerClient:        c.config.Crypto.macs(),
-		CompressionClientServer: supportedCompressions,
-		CompressionServerClient: supportedCompressions,
-	}
-	kexInitPacket := marshal(msgKexInit, clientKexInit)
-	if err := c.transport.writePacket(kexInitPacket); err != nil {
-		return err
-	}
-	packet, err := c.transport.readPacket()
+	var err error
+	c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
 	if err != nil {
 		return err
 	}
 
-	var serverKexInit kexInitMsg
-	if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
+	c.transport = newClientTransport(
+		newTransport(c.sshConn.conn, config.Rand, true /* is client */),
+		c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
+	if err := c.transport.requestKeyChange(); err != nil {
 		return err
 	}
 
-	algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit)
-	if algs == nil {
-		return errors.New("ssh: no common algorithms")
-	}
-
-	if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
-		// The server sent a Kex message for the wrong algorithm,
-		// which we have to ignore.
-		if _, err := c.transport.readPacket(); err != nil {
-			return err
-		}
-	}
-
-	kex, ok := kexAlgoMap[algs.kex]
-	if !ok {
-		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
-	}
-
-	magics := handshakeMagics{
-		clientVersion: clientVersion,
-		serverVersion: serverVersion,
-		clientKexInit: kexInitPacket,
-		serverKexInit: packet,
-	}
-	result, err := kex.Client(c.transport, c.config.rand(), &magics)
-	if err != nil {
+	if packet, err := c.transport.readPacket(); err != nil {
 		return err
+	} else if packet[0] != msgNewKeys {
+		return unexpectedMessageError(msgNewKeys, packet[0])
 	}
-
-	err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, result.Signature)
-	if err != nil {
-		return err
-	}
-
-	if checker := c.config.HostKeyChecker; checker != nil {
-		err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey)
-		if err != nil {
-			return err
-		}
-	}
-
-	c.transport.prepareKeyChange(algs, result)
-
-	if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil {
-		return err
-	}
-	if packet, err = c.transport.readPacket(); err != nil {
-		return err
-	}
-	if packet[0] != msgNewKeys {
-		return UnexpectedMessageError{msgNewKeys, packet[0]}
-	}
-	return c.authenticate()
+	return c.clientAuthenticate(config)
 }
 
-// Verify the host key obtained in the key exchange.
-func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error {
-	hostKey, rest, ok := ParsePublicKey(hostKeyBytes)
-	if len(rest) > 0 || !ok {
-		return errors.New("ssh: could not parse hostkey")
-	}
-
-	sig, rest, ok := parseSignatureBody(signature)
+// verifyHostKeySignature verifies the host key obtained in the key
+// exchange.
+func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error {
+	sig, rest, ok := parseSignatureBody(result.Signature)
 	if len(rest) > 0 || !ok {
 		return errors.New("ssh: signature parse error")
 	}
-	if sig.Format != hostKeyAlgo {
-		return fmt.Errorf("ssh: unexpected signature type %q", sig.Format)
-	}
 
-	if !hostKey.Verify(data, sig.Blob) {
-		return errors.New("ssh: host key signature error")
-	}
-	return nil
+	return hostKey.Verify(result.H, sig)
 }
 
-// mainLoop reads incoming messages and routes channel messages
-// to their respective ClientChans.
-func (c *ClientConn) mainLoop() {
-	defer func() {
-		c.transport.Close()
-		c.chanList.closeAll()
-		c.forwardList.closeAll()
-	}()
-
-	for {
-		packet, err := c.transport.readPacket()
-		if err != nil {
-			break
-		}
-		// TODO(dfc) A note on blocking channel use.
-		// The msg, data and dataExt channels of a clientChan can
-		// cause this loop to block indefinitely if the consumer does
-		// not service them.
-		switch packet[0] {
-		case msgChannelData:
-			if len(packet) < 9 {
-				// malformed data packet
-				return
-			}
-			remoteId := binary.BigEndian.Uint32(packet[1:5])
-			length := binary.BigEndian.Uint32(packet[5:9])
-			packet = packet[9:]
-
-			if length != uint32(len(packet)) {
-				return
-			}
-			ch, ok := c.getChan(remoteId)
-			if !ok {
-				return
-			}
-			ch.stdout.write(packet)
-		case msgChannelExtendedData:
-			if len(packet) < 13 {
-				// malformed data packet
-				return
-			}
-			remoteId := binary.BigEndian.Uint32(packet[1:5])
-			datatype := binary.BigEndian.Uint32(packet[5:9])
-			length := binary.BigEndian.Uint32(packet[9:13])
-			packet = packet[13:]
-
-			if length != uint32(len(packet)) {
-				return
-			}
-			// RFC 4254 5.2 defines data_type_code 1 to be data destined
-			// for stderr on interactive sessions. Other data types are
-			// silently discarded.
-			if datatype == 1 {
-				ch, ok := c.getChan(remoteId)
-				if !ok {
-					return
-				}
-				ch.stderr.write(packet)
-			}
-		default:
-			decoded, err := decode(packet)
-			if err != nil {
-				if _, ok := err.(UnexpectedMessageError); ok {
-					fmt.Printf("mainLoop: unexpected message: %v\n", err)
-					continue
-				}
-				return
-			}
-			switch msg := decoded.(type) {
-			case *channelOpenMsg:
-				c.handleChanOpen(msg)
-			case *channelOpenConfirmMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				ch.msg <- msg
-			case *channelOpenFailureMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				ch.msg <- msg
-			case *channelCloseMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				ch.Close()
-				close(ch.msg)
-				c.chanList.remove(msg.PeersId)
-			case *channelEOFMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				ch.stdout.eof()
-				// RFC 4254 is mute on how EOF affects dataExt messages but
-				// it is logical to signal EOF at the same time.
-				ch.stderr.eof()
-			case *channelRequestSuccessMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				ch.msg <- msg
-			case *channelRequestFailureMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				ch.msg <- msg
-			case *channelRequestMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				ch.msg <- msg
-			case *windowAdjustMsg:
-				ch, ok := c.getChan(msg.PeersId)
-				if !ok {
-					return
-				}
-				if !ch.remoteWin.add(msg.AdditionalBytes) {
-					// invalid window update
-					return
-				}
-			case *globalRequestMsg:
-				// This handles keepalive messages and matches
-				// the behaviour of OpenSSH.
-				if msg.WantReply {
-					c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
-				}
-			case *globalRequestSuccessMsg, *globalRequestFailureMsg:
-				c.globalRequest.response <- msg
-			case *disconnectMsg:
-				return
-			default:
-				fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
-			}
-		}
-	}
-}
-
-// Handle channel open messages from the remote side.
-func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
-	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
-		c.sendConnectionFailed(msg.PeersId)
-	}
-
-	switch msg.ChanType {
-	case "forwarded-tcpip":
-		laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData)
-		if !ok {
-			// invalid request
-			c.sendConnectionFailed(msg.PeersId)
-			return
-		}
-
-		l, ok := c.forwardList.lookup(*laddr)
-		if !ok {
-			// TODO: print on a more structured log.
-			fmt.Println("could not find forward list entry for", laddr)
-			// Section 7.2, implementations MUST reject spurious incoming
-			// connections.
-			c.sendConnectionFailed(msg.PeersId)
-			return
-		}
-		raddr, rest, ok := parseTCPAddr(rest)
-		if !ok {
-			// invalid request
-			c.sendConnectionFailed(msg.PeersId)
-			return
-		}
-		ch := c.newChan(c.transport)
-		ch.remoteId = msg.PeersId
-		ch.remoteWin.add(msg.PeersWindow)
-		ch.maxPacket = msg.MaxPacketSize
-
-		m := channelOpenConfirmMsg{
-			PeersId:       ch.remoteId,
-			MyId:          ch.localId,
-			MyWindow:      channelWindowSize,
-			MaxPacketSize: channelMaxPacketSize,
-		}
-
-		c.transport.writePacket(marshal(msgChannelOpenConfirm, m))
-		l <- forward{ch, raddr}
-	default:
-		// unknown channel type
-		m := channelOpenFailureMsg{
-			PeersId:  msg.PeersId,
-			Reason:   UnknownChannelType,
-			Message:  fmt.Sprintf("unknown channel type: %v", msg.ChanType),
-			Language: "en_US.UTF-8",
-		}
-		c.transport.writePacket(marshal(msgChannelOpenFailure, m))
-	}
-}
-
-// sendGlobalRequest sends a global request message as specified
-// in RFC4254 section 4. To correctly synchronise messages, a lock
-// is held internally until a response is returned.
-func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
-	c.globalRequest.Lock()
-	defer c.globalRequest.Unlock()
-	if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil {
+// NewSession opens a new Session for this client. (A session is a remote
+// execution of a program.)
+func (c *Client) NewSession() (*Session, error) {
+	ch, in, err := c.OpenChannel("session", nil)
+	if err != nil {
 		return nil, err
 	}
-	r := <-c.globalRequest.response
-	if r, ok := r.(*globalRequestSuccessMsg); ok {
-		return r, nil
-	}
-	return nil, errors.New("request failed")
+	return newSession(ch, in)
 }
 
-// sendConnectionFailed rejects an incoming channel identified
-// by remoteId.
-func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
-	m := channelOpenFailureMsg{
-		PeersId:  remoteId,
-		Reason:   ConnectionFailed,
-		Message:  "invalid request",
-		Language: "en_US.UTF-8",
+func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
+	for r := range incoming {
+		// This handles keepalive messages and matches
+		// the behaviour of OpenSSH.
+		r.Reply(false, nil)
 	}
-	return c.transport.writePacket(marshal(msgChannelOpenFailure, m))
+}
+
+// handleChannelOpens channel open messages from the remote side.
+func (c *Client) handleChannelOpens(in <-chan NewChannel) {
+	for ch := range in {
+		c.mu.Lock()
+		handler := c.channelHandlers[ch.ChannelType()]
+		c.mu.Unlock()
+
+		if handler != nil {
+			handler <- ch
+		} else {
+			ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType()))
+		}
+	}
+
+	c.mu.Lock()
+	for _, ch := range c.channelHandlers {
+		close(ch)
+	}
+	c.channelHandlers = nil
+	c.mu.Unlock()
 }
 
 // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
@@ -413,7 +168,7 @@
 		return nil, b, false
 	}
 	port, b, ok := parseUint32(b)
-	if !ok {
+	if !ok || port == 0 || port > 65535 {
 		return nil, b, false
 	}
 	ip := net.ParseIP(string(addr))
@@ -423,102 +178,44 @@
 	return &net.TCPAddr{IP: ip, Port: int(port)}, b, true
 }
 
-// Dial connects to the given network address using net.Dial and
-// then initiates a SSH handshake, returning the resulting client connection.
-func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
+// Dial starts a client connection to the given SSH server. It is a
+// convenience function that connects to the given network address,
+// initiates the SSH handshake, and then sets up a Client.  For access
+// to incoming channels and requests, use net.Dial with NewClientConn
+// instead.
+func Dial(network, addr string, config *ClientConfig) (*Client, error) {
 	conn, err := net.Dial(network, addr)
 	if err != nil {
 		return nil, err
 	}
-	return clientWithAddress(conn, addr, config)
+	c, chans, reqs, err := NewClientConn(conn, addr, config)
+	if err != nil {
+		return nil, err
+	}
+	return NewClient(c, chans, reqs), nil
 }
 
-// A ClientConfig structure is used to configure a ClientConn. After one has
-// been passed to an SSH function it must not be modified.
+// A ClientConfig structure is used to configure a Client. It must not be
+// modified after having been passed to an SSH function.
 type ClientConfig struct {
-	// Rand provides the source of entropy for key exchange. If Rand is
-	// nil, the cryptographic random reader in package crypto/rand will
-	// be used.
-	Rand io.Reader
+	// Config contains configuration that is shared between clients and
+	// servers.
+	Config
 
-	// The username to authenticate.
+	// User contains the username to authenticate as.
 	User string
 
-	// A slice of ClientAuth methods. Only the first instance
-	// of a particular RFC 4252 method will be used during authentication.
-	Auth []ClientAuth
+	// Auth contains possible authentication methods to use with the
+	// server. Only the first instance of a particular RFC 4252 method will
+	// be used during authentication.
+	Auth []AuthMethod
 
-	// HostKeyChecker, if not nil, is called during the cryptographic
-	// handshake to validate the server's host key. A nil HostKeyChecker
+	// HostKeyCallback, if not nil, is called during the cryptographic
+	// handshake to validate the server's host key. A nil HostKeyCallback
 	// implies that all host keys are accepted.
-	HostKeyChecker HostKeyChecker
+	HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
 
-	// Cryptographic-related configuration.
-	Crypto CryptoConfig
-
-	// The identification string that will be used for the connection.
-	// If empty, a reasonable default is used.
+	// ClientVersion contains the version identification string that will
+	// be used for the connection. If empty, a reasonable default is used.
 	ClientVersion string
 }
-
-func (c *ClientConfig) rand() io.Reader {
-	if c.Rand == nil {
-		return rand.Reader
-	}
-	return c.Rand
-}
-
-// Thread safe channel list.
-type chanList struct {
-	// protects concurrent access to chans
-	sync.Mutex
-	// chans are indexed by the local id of the channel, clientChan.localId.
-	// The PeersId value of messages received by ClientConn.mainLoop is
-	// used to locate the right local clientChan in this slice.
-	chans []*clientChan
-}
-
-// Allocate a new ClientChan with the next avail local id.
-func (c *chanList) newChan(p packetConn) *clientChan {
-	c.Lock()
-	defer c.Unlock()
-	for i := range c.chans {
-		if c.chans[i] == nil {
-			ch := newClientChan(p, uint32(i))
-			c.chans[i] = ch
-			return ch
-		}
-	}
-	i := len(c.chans)
-	ch := newClientChan(p, uint32(i))
-	c.chans = append(c.chans, ch)
-	return ch
-}
-
-func (c *chanList) getChan(id uint32) (*clientChan, bool) {
-	c.Lock()
-	defer c.Unlock()
-	if id >= uint32(len(c.chans)) {
-		return nil, false
-	}
-	return c.chans[id], true
-}
-
-func (c *chanList) remove(id uint32) {
-	c.Lock()
-	defer c.Unlock()
-	c.chans[id] = nil
-}
-
-func (c *chanList) closeAll() {
-	c.Lock()
-	defer c.Unlock()
-
-	for _, ch := range c.chans {
-		if ch == nil {
-			continue
-		}
-		ch.Close()
-		close(ch.msg)
-	}
-}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 29be0ca..5b7aa30 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -5,16 +5,16 @@
 package ssh
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"io"
-	"net"
 )
 
-// authenticate authenticates with the remote server. See RFC 4252.
-func (c *ClientConn) authenticate() error {
+// clientAuthenticate authenticates with the remote server. See RFC 4252.
+func (c *connection) clientAuthenticate(config *ClientConfig) error {
 	// initiate user auth session
-	if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
+	if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
 		return err
 	}
 	packet, err := c.transport.readPacket()
@@ -22,14 +22,15 @@
 		return err
 	}
 	var serviceAccept serviceAcceptMsg
-	if err := unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
+	if err := Unmarshal(packet, &serviceAccept); err != nil {
 		return err
 	}
+
 	// during the authentication phase the client first attempts the "none" method
 	// then any untried methods suggested by the server.
-	tried, remain := make(map[string]bool), make(map[string]bool)
-	for auth := ClientAuth(new(noneAuth)); auth != nil; {
-		ok, methods, err := auth.auth(c.transport.sessionID, c.config.User, c.transport, c.config.rand())
+	tried := make(map[string]bool)
+	for auth := AuthMethod(new(noneAuth)); auth != nil; {
+		ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
 		if err != nil {
 			return err
 		}
@@ -38,45 +39,35 @@
 			return nil
 		}
 		tried[auth.method()] = true
-		delete(remain, auth.method())
-		for _, meth := range methods {
-			if tried[meth] {
-				// if we've tried meth already, skip it.
-				continue
-			}
-			remain[meth] = true
-		}
+
 		auth = nil
-		for _, a := range c.config.Auth {
-			if remain[a.method()] {
-				auth = a
-				break
+		for _, a := range config.Auth {
+			candidateMethod := a.method()
+			for _, meth := range methods {
+				if meth != candidateMethod {
+					continue
+				}
+				if !tried[meth] {
+					auth = a
+					break
+				}
 			}
 		}
 	}
 	return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried))
 }
 
-func keys(m map[string]bool) (s []string) {
-	for k := range m {
-		s = append(s, k)
+func keys(m map[string]bool) []string {
+	s := make([]string, 0, len(m))
+
+	for key := range m {
+		s = append(s, key)
 	}
-	return
+	return s
 }
 
-// HostKeyChecker represents a database of known server host keys.
-type HostKeyChecker interface {
-	// Check is called during the handshake to check server's
-	// public key for unexpected changes. The hostKey argument is
-	// in SSH wire format. It can be parsed using
-	// ssh.ParsePublicKey. The address before DNS resolution is
-	// passed in the addr argument, so the key can also be checked
-	// against the hostname.
-	Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error
-}
-
-// A ClientAuth represents an instance of an RFC 4252 authentication method.
-type ClientAuth interface {
+// An AuthMethod represents an instance of an RFC 4252 authentication method.
+type AuthMethod interface {
 	// auth authenticates user over transport t.
 	// Returns true if authentication is successful.
 	// If authentication is not successful, a []string of alternative
@@ -91,7 +82,7 @@
 type noneAuth int
 
 func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
-	if err := c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
+	if err := c.writePacket(Marshal(&userAuthRequestMsg{
 		User:    user,
 		Service: serviceSSH,
 		Method:  "none",
@@ -106,29 +97,31 @@
 	return "none"
 }
 
-// "password" authentication, RFC 4252 Section 8.
-type passwordAuth struct {
-	ClientPassword
-}
+// passwordCallback is an AuthMethod that fetches the password through
+// a function call, e.g. by prompting the user.
+type passwordCallback func() (password string, err error)
 
-func (p *passwordAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
 	type passwordAuthMsg struct {
-		User     string
+		User     string `sshtype:"50"`
 		Service  string
 		Method   string
 		Reply    bool
 		Password string
 	}
 
-	pw, err := p.Password(user)
+	pw, err := cb()
+	// REVIEW NOTE: is there a need to support skipping a password attempt?
+	// The program may only find out that the user doesn't have a password
+	// when prompting.
 	if err != nil {
 		return false, nil, err
 	}
 
-	if err := c.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{
+	if err := c.writePacket(Marshal(&passwordAuthMsg{
 		User:     user,
 		Service:  serviceSSH,
-		Method:   "password",
+		Method:   cb.method(),
 		Reply:    false,
 		Password: pw,
 	})); err != nil {
@@ -138,106 +131,93 @@
 	return handleAuthResponse(c)
 }
 
-func (p *passwordAuth) method() string {
+func (cb passwordCallback) method() string {
 	return "password"
 }
 
-// A ClientPassword implements access to a client's passwords.
-type ClientPassword interface {
-	// Password returns the password to use for user.
-	Password(user string) (password string, err error)
+// Password returns an AuthMethod using the given password.
+func Password(secret string) AuthMethod {
+	return passwordCallback(func() (string, error) { return secret, nil })
 }
 
-// ClientAuthPassword returns a ClientAuth using password authentication.
-func ClientAuthPassword(impl ClientPassword) ClientAuth {
-	return &passwordAuth{impl}
-}
-
-// ClientKeyring implements access to a client key ring.
-type ClientKeyring interface {
-	// Key returns the i'th Publickey, or nil if no key exists at i.
-	Key(i int) (key PublicKey, err error)
-
-	// Sign returns a signature of the given data using the i'th key
-	// and the supplied random source.
-	Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
-}
-
-// "publickey" authentication, RFC 4252 Section 7.
-type publickeyAuth struct {
-	ClientKeyring
+// PasswordCallback returns an AuthMethod that uses a callback for
+// fetching a password.
+func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
+	return passwordCallback(prompt)
 }
 
 type publickeyAuthMsg struct {
-	User    string
+	User    string `sshtype:"50"`
 	Service string
 	Method  string
 	// HasSig indicates to the receiver packet that the auth request is signed and
 	// should be used for authentication of the request.
 	HasSig   bool
 	Algoname string
-	Pubkey   string
-	// Sig is defined as []byte so marshal will exclude it during validateKey
+	PubKey   []byte
+	// Sig is tagged with "rest" so Marshal will exclude it during
+	// validateKey
 	Sig []byte `ssh:"rest"`
 }
 
-func (p *publickeyAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+// publicKeyCallback is an AuthMethod that uses a set of key
+// pairs for authentication.
+type publicKeyCallback func() ([]Signer, error)
+
+func (cb publicKeyCallback) method() string {
+	return "publickey"
+}
+
+func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
 	// Authentication is performed in two stages. The first stage sends an
 	// enquiry to test if each key is acceptable to the remote. The second
 	// stage attempts to authenticate with the valid keys obtained in the
 	// first stage.
 
-	var index int
-	// a map of public keys to their index in the keyring
-	validKeys := make(map[int]PublicKey)
-	for {
-		key, err := p.Key(index)
-		if err != nil {
-			return false, nil, err
-		}
-		if key == nil {
-			// no more keys in the keyring
-			break
-		}
-
-		if ok, err := p.validateKey(key, user, c); ok {
-			validKeys[index] = key
+	signers, err := cb()
+	if err != nil {
+		return false, nil, err
+	}
+	var validKeys []Signer
+	for _, signer := range signers {
+		if ok, err := validateKey(signer.PublicKey(), user, c); ok {
+			validKeys = append(validKeys, signer)
 		} else {
 			if err != nil {
 				return false, nil, err
 			}
 		}
-		index++
 	}
 
 	// methods that may continue if this auth is not successful.
 	var methods []string
-	for i, key := range validKeys {
-		pubkey := MarshalPublicKey(key)
-		algoname := key.PublicKeyAlgo()
-		data := buildDataSignedForAuth(session, userAuthRequestMsg{
+	for _, signer := range validKeys {
+		pub := signer.PublicKey()
+
+		pubKey := pub.Marshal()
+		sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
 			User:    user,
 			Service: serviceSSH,
-			Method:  p.method(),
-		}, []byte(algoname), pubkey)
-		sigBlob, err := p.Sign(i, rand, data)
+			Method:  cb.method(),
+		}, []byte(pub.Type()), pubKey))
 		if err != nil {
 			return false, nil, err
 		}
+
 		// manually wrap the serialized signature in a string
-		s := serializeSignature(key.PublicKeyAlgo(), sigBlob)
+		s := Marshal(sign)
 		sig := make([]byte, stringLength(len(s)))
 		marshalString(sig, s)
 		msg := publickeyAuthMsg{
 			User:     user,
 			Service:  serviceSSH,
-			Method:   p.method(),
+			Method:   cb.method(),
 			HasSig:   true,
-			Algoname: algoname,
-			Pubkey:   string(pubkey),
+			Algoname: pub.Type(),
+			PubKey:   pubKey,
 			Sig:      sig,
 		}
-		p := marshal(msgUserAuthRequest, msg)
+		p := Marshal(&msg)
 		if err := c.writePacket(p); err != nil {
 			return false, nil, err
 		}
@@ -252,28 +232,27 @@
 	return false, methods, nil
 }
 
-// validateKey validates the key provided it is acceptable to the server.
-func (p *publickeyAuth) validateKey(key PublicKey, user string, c packetConn) (bool, error) {
-	pubkey := MarshalPublicKey(key)
-	algoname := key.PublicKeyAlgo()
+// validateKey validates the key provided is acceptable to the server.
+func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
+	pubKey := key.Marshal()
 	msg := publickeyAuthMsg{
 		User:     user,
 		Service:  serviceSSH,
-		Method:   p.method(),
+		Method:   "publickey",
 		HasSig:   false,
-		Algoname: algoname,
-		Pubkey:   string(pubkey),
+		Algoname: key.Type(),
+		PubKey:   pubKey,
 	}
-	if err := c.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
+	if err := c.writePacket(Marshal(&msg)); err != nil {
 		return false, err
 	}
 
-	return p.confirmKeyAck(key, c)
+	return confirmKeyAck(key, c)
 }
 
-func (p *publickeyAuth) confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
-	pubkey := MarshalPublicKey(key)
-	algoname := key.PublicKeyAlgo()
+func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
+	pubKey := key.Marshal()
+	algoname := key.Type()
 
 	for {
 		packet, err := c.readPacket()
@@ -284,30 +263,32 @@
 		case msgUserAuthBanner:
 			// TODO(gpaul): add callback to present the banner to the user
 		case msgUserAuthPubKeyOk:
-			msg := userAuthPubKeyOkMsg{}
-			if err := unmarshal(&msg, packet, msgUserAuthPubKeyOk); err != nil {
+			var msg userAuthPubKeyOkMsg
+			if err := Unmarshal(packet, &msg); err != nil {
 				return false, err
 			}
-			if msg.Algo != algoname || msg.PubKey != string(pubkey) {
+			if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) {
 				return false, nil
 			}
 			return true, nil
 		case msgUserAuthFailure:
 			return false, nil
 		default:
-			return false, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+			return false, unexpectedMessageError(msgUserAuthSuccess, packet[0])
 		}
 	}
-	panic("unreachable")
 }
 
-func (p *publickeyAuth) method() string {
-	return "publickey"
+// PublicKeys returns an AuthMethod that uses the given key
+// pairs.
+func PublicKeys(signers ...Signer) AuthMethod {
+	return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
 }
 
-// ClientAuthKeyring returns a ClientAuth using public key authentication.
-func ClientAuthKeyring(impl ClientKeyring) ClientAuth {
-	return &publickeyAuth{impl}
+// PublicKeysCallback returns an AuthMethod that runs the given
+// function to obtain a list of key pairs.
+func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
+	return publicKeyCallback(getSigners)
 }
 
 // handleAuthResponse returns whether the preceding authentication request succeeded
@@ -324,8 +305,8 @@
 		case msgUserAuthBanner:
 			// TODO: add callback to present the banner to the user
 		case msgUserAuthFailure:
-			msg := userAuthFailureMsg{}
-			if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil {
+			var msg userAuthFailureMsg
+			if err := Unmarshal(packet, &msg); err != nil {
 				return false, nil, err
 			}
 			return false, msg.Methods, nil
@@ -334,98 +315,40 @@
 		case msgDisconnect:
 			return false, nil, io.EOF
 		default:
-			return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+			return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
 		}
 	}
-	panic("unreachable")
 }
 
-// ClientAuthAgent returns a ClientAuth using public key authentication via
-// an agent.
-func ClientAuthAgent(agent *AgentClient) ClientAuth {
-	return ClientAuthKeyring(&agentKeyring{agent: agent})
+// KeyboardInteractiveChallenge should print questions, optionally
+// disabling echoing (e.g. for passwords), and return all the answers.
+// Challenge may be called multiple times in a single session. After
+// successful authentication, the server may send a challenge with no
+// questions, for which the user and instruction messages should be
+// printed.  RFC 4256 section 3.3 details how the UI should behave for
+// both CLI and GUI environments.
+type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
+
+// KeyboardInteractive returns a AuthMethod using a prompt/response
+// sequence controlled by the server.
+func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
+	return challenge
 }
 
-// agentKeyring implements ClientKeyring.
-type agentKeyring struct {
-	agent *AgentClient
-	keys  []*AgentKey
-}
-
-func (kr *agentKeyring) Key(i int) (key PublicKey, err error) {
-	if kr.keys == nil {
-		if kr.keys, err = kr.agent.RequestIdentities(); err != nil {
-			return
-		}
-	}
-	if i >= len(kr.keys) {
-		return
-	}
-	return kr.keys[i].Key()
-}
-
-func (kr *agentKeyring) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
-	var key PublicKey
-	if key, err = kr.Key(i); err != nil {
-		return
-	}
-	if key == nil {
-		return nil, errors.New("ssh: key index out of range")
-	}
-	if sig, err = kr.agent.SignRequest(key, data); err != nil {
-		return
-	}
-
-	// Unmarshal the signature.
-
-	var ok bool
-	if _, sig, ok = parseString(sig); !ok {
-		return nil, errors.New("ssh: malformed signature response from agent")
-	}
-	if sig, _, ok = parseString(sig); !ok {
-		return nil, errors.New("ssh: malformed signature response from agent")
-	}
-	return sig, nil
-}
-
-// ClientKeyboardInteractive should prompt the user for the given
-// questions.
-type ClientKeyboardInteractive interface {
-	// Challenge should print the questions, optionally disabling
-	// echoing (eg. for passwords), and return all the answers.
-	// Challenge may be called multiple times in a single
-	// session. After successful authentication, the server may
-	// send a challenge with no questions, for which the user and
-	// instruction messages should be printed.  RFC 4256 section
-	// 3.3 details how the UI should behave for both CLI and
-	// GUI environments.
-	Challenge(user, instruction string, questions []string, echos []bool) ([]string, error)
-}
-
-// ClientAuthKeyboardInteractive returns a ClientAuth using a
-// prompt/response sequence controlled by the server.
-func ClientAuthKeyboardInteractive(impl ClientKeyboardInteractive) ClientAuth {
-	return &keyboardInteractiveAuth{impl}
-}
-
-type keyboardInteractiveAuth struct {
-	ClientKeyboardInteractive
-}
-
-func (k *keyboardInteractiveAuth) method() string {
+func (cb KeyboardInteractiveChallenge) method() string {
 	return "keyboard-interactive"
 }
 
-func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
 	type initiateMsg struct {
-		User       string
+		User       string `sshtype:"50"`
 		Service    string
 		Method     string
 		Language   string
 		Submethods string
 	}
 
-	if err := c.writePacket(marshal(msgUserAuthRequest, initiateMsg{
+	if err := c.writePacket(Marshal(&initiateMsg{
 		User:    user,
 		Service: serviceSSH,
 		Method:  "keyboard-interactive",
@@ -448,18 +371,18 @@
 			// OK
 		case msgUserAuthFailure:
 			var msg userAuthFailureMsg
-			if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil {
+			if err := Unmarshal(packet, &msg); err != nil {
 				return false, nil, err
 			}
 			return false, msg.Methods, nil
 		case msgUserAuthSuccess:
 			return true, nil, nil
 		default:
-			return false, nil, UnexpectedMessageError{msgUserAuthInfoRequest, packet[0]}
+			return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
 		}
 
 		var msg userAuthInfoRequestMsg
-		if err := unmarshal(&msg, packet, packet[0]); err != nil {
+		if err := Unmarshal(packet, &msg); err != nil {
 			return false, nil, err
 		}
 
@@ -478,10 +401,10 @@
 		}
 
 		if len(rest) != 0 {
-			return false, nil, fmt.Errorf("ssh: junk following message %q", rest)
+			return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
 		}
 
-		answers, err := k.Challenge(msg.User, msg.Instruction, prompts, echos)
+		answers, err := cb(msg.User, msg.Instruction, prompts, echos)
 		if err != nil {
 			return false, nil, err
 		}
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index f2fc9c6..3173255 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -6,363 +6,317 @@
 
 import (
 	"bytes"
-	"crypto/dsa"
-	"io"
-	"io/ioutil"
-	"math/big"
+	"crypto/rand"
+	"errors"
+	"fmt"
 	"strings"
 	"testing"
-
-	_ "crypto/sha1"
 )
 
-// private key for mock server
-const testServerPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
-MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
-70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
-9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
-tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
-s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
-qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
-+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
-riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
-D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
-atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
-b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
-ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
-MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
-KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
-e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
-D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
-3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
-orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
-64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
-XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
-QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
-/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
-I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
-gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
-NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
------END RSA PRIVATE KEY-----`
-
-const testClientPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
-MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
-r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
-tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
-nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
-2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
-y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
-rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
------END RSA PRIVATE KEY-----`
-
-// keychain implements the ClientKeyring interface
-type keychain struct {
-	keys []Signer
-}
-
-func (k *keychain) Key(i int) (PublicKey, error) {
-	if i < 0 || i >= len(k.keys) {
-		return nil, nil
-	}
-
-	return k.keys[i].PublicKey(), nil
-}
-
-func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
-	return k.keys[i].Sign(rand, data)
-}
-
-func (k *keychain) add(key Signer) {
-	k.keys = append(k.keys, key)
-}
-
-func (k *keychain) loadPEM(file string) error {
-	buf, err := ioutil.ReadFile(file)
-	if err != nil {
-		return err
-	}
-	key, err := ParsePrivateKey(buf)
-	if err != nil {
-		return err
-	}
-	k.add(key)
-	return nil
-}
-
-// password implements the ClientPassword interface
-type password string
-
-func (p password) Password(user string) (string, error) {
-	return string(p), nil
-}
-
 type keyboardInteractive map[string]string
 
-func (cr *keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
+func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
 	var answers []string
 	for _, q := range questions {
-		answers = append(answers, (*cr)[q])
+		answers = append(answers, cr[q])
 	}
 	return answers, nil
 }
 
 // reused internally by tests
-var (
-	rsaKey         Signer
-	dsaKey         Signer
-	clientKeychain = new(keychain)
-	clientPassword = password("tiger")
-	serverConfig   = &ServerConfig{
-		PasswordCallback: func(conn *ServerConn, user, pass string) bool {
-			return user == "testuser" && pass == string(clientPassword)
+var clientPassword = "tiger"
+
+// tryAuth runs a handshake with a given config against an SSH server
+// with config serverConfig
+func tryAuth(t *testing.T, config *ClientConfig) error {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+
+	certChecker := CertChecker{
+		IsAuthority: func(k PublicKey) bool {
+			return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
 		},
-		PublicKeyCallback: func(conn *ServerConn, user, algo string, pubkey []byte) bool {
-			key, _ := clientKeychain.Key(0)
-			expected := MarshalPublicKey(key)
-			algoname := key.PublicKeyAlgo()
-			return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
+		UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+			if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+				return nil, nil
+			}
+
+			return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
 		},
-		KeyboardInteractiveCallback: func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool {
-			ans, err := client.Challenge("user",
+		IsRevoked: func(c *Certificate) bool {
+			return c.Serial == 666
+		},
+	}
+
+	serverConfig := &ServerConfig{
+		PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
+			if conn.User() == "testuser" && string(pass) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("password auth failed")
+		},
+		PublicKeyCallback: certChecker.Authenticate,
+		KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) {
+			ans, err := challenge("user",
 				"instruction",
 				[]string{"question1", "question2"},
 				[]bool{true, true})
 			if err != nil {
-				return false
+				return nil, err
 			}
-			ok := user == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
-			client.Challenge("user", "motd", nil, nil)
-			return ok
+			ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
+			if ok {
+				challenge("user", "motd", nil, nil)
+				return nil, nil
+			}
+			return nil, errors.New("keyboard-interactive failed")
+		},
+		AuthLogCallback: func(conn ConnMetadata, method string, err error) {
+			t.Logf("user %q, method %q: %v", conn.User(), method, err)
 		},
 	}
-)
+	serverConfig.AddHostKey(testSigners["rsa"])
 
-func init() {
-	var err error
-	rsaKey, err = ParsePrivateKey([]byte(testServerPrivateKey))
-	if err != nil {
-		panic("unable to set private key: " + err.Error())
-	}
-	rawDSAKey := new(dsa.PrivateKey)
-
-	// taken from crypto/dsa/dsa_test.go
-	rawDSAKey.P, _ = new(big.Int).SetString("A9B5B793FB4785793D246BAE77E8FF63CA52F442DA763C440259919FE1BC1D6065A9350637A04F75A2F039401D49F08E066C4D275A5A65DA5684BC563C14289D7AB8A67163BFBF79D85972619AD2CFF55AB0EE77A9002B0EF96293BDD0F42685EBB2C66C327079F6C98000FBCB79AACDE1BC6F9D5C7B1A97E3D9D54ED7951FEF", 16)
-	rawDSAKey.Q, _ = new(big.Int).SetString("E1D3391245933D68A0714ED34BBCB7A1F422B9C1", 16)
-	rawDSAKey.G, _ = new(big.Int).SetString("634364FC25248933D01D1993ECABD0657CC0CB2CEED7ED2E3E8AECDFCDC4A25C3B15E9E3B163ACA2984B5539181F3EFF1A5E8903D71D5B95DA4F27202B77D2C44B430BB53741A8D59A8F86887525C9F2A6A5980A195EAA7F2FF910064301DEF89D3AA213E1FAC7768D89365318E370AF54A112EFBA9246D9158386BA1B4EEFDA", 16)
-	rawDSAKey.Y, _ = new(big.Int).SetString("32969E5780CFE1C849A1C276D7AEB4F38A23B591739AA2FE197349AEEBD31366AEE5EB7E6C6DDB7C57D02432B30DB5AA66D9884299FAA72568944E4EEDC92EA3FBC6F39F53412FBCC563208F7C15B737AC8910DBC2D9C9B8C001E72FDC40EB694AB1F06A5A2DBD18D9E36C66F31F566742F11EC0A52E9F7B89355C02FB5D32D2", 16)
-	rawDSAKey.X, _ = new(big.Int).SetString("5078D4D29795CBE76D3AACFE48C9AF0BCDBEE91A", 16)
-
-	dsaKey, err = NewSignerFromKey(rawDSAKey)
-	if err != nil {
-		panic("NewSignerFromKey: " + err.Error())
-	}
-	clientKeychain.add(rsaKey)
-	serverConfig.AddHostKey(rsaKey)
-}
-
-// newMockAuthServer creates a new Server bound to
-// the loopback interface. The server exits after
-// processing one handshake.
-func newMockAuthServer(t *testing.T) string {
-	l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
-	if err != nil {
-		t.Fatalf("unable to newMockAuthServer: %s", err)
-	}
-	go func() {
-		defer l.Close()
-		c, err := l.Accept()
-		if err != nil {
-			t.Errorf("Unable to accept incoming connection: %v", err)
-			return
-		}
-		if err := c.Handshake(); err != nil {
-			// not Errorf because this is expected to
-			// fail for some tests.
-			t.Logf("Handshaking error: %v", err)
-			return
-		}
-		defer c.Close()
-	}()
-	return l.Addr().String()
+	go newServer(c1, serverConfig)
+	_, _, _, err = NewClientConn(c2, "", config)
+	return err
 }
 
 func TestClientAuthPublicKey(t *testing.T) {
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthKeyring(clientKeychain),
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["rsa"]),
 		},
 	}
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err != nil {
+	if err := tryAuth(t, config); err != nil {
 		t.Fatalf("unable to dial remote side: %s", err)
 	}
-	c.Close()
 }
 
-func TestClientAuthPassword(t *testing.T) {
+func TestAuthMethodPassword(t *testing.T) {
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthPassword(clientPassword),
+		Auth: []AuthMethod{
+			Password(clientPassword),
 		},
 	}
 
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err != nil {
+	if err := tryAuth(t, config); err != nil {
 		t.Fatalf("unable to dial remote side: %s", err)
 	}
-	c.Close()
 }
 
-func TestClientAuthWrongPassword(t *testing.T) {
-	wrongPw := password("wrong")
+func TestAuthMethodWrongPassword(t *testing.T) {
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthPassword(wrongPw),
-			ClientAuthKeyring(clientKeychain),
+		Auth: []AuthMethod{
+			Password("wrong"),
+			PublicKeys(testSigners["rsa"]),
 		},
 	}
 
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err != nil {
+	if err := tryAuth(t, config); err != nil {
 		t.Fatalf("unable to dial remote side: %s", err)
 	}
-	c.Close()
 }
 
-func TestClientAuthKeyboardInteractive(t *testing.T) {
+func TestAuthMethodKeyboardInteractive(t *testing.T) {
 	answers := keyboardInteractive(map[string]string{
 		"question1": "answer1",
 		"question2": "answer2",
 	})
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthKeyboardInteractive(&answers),
+		Auth: []AuthMethod{
+			KeyboardInteractive(answers.Challenge),
 		},
 	}
 
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err != nil {
+	if err := tryAuth(t, config); err != nil {
 		t.Fatalf("unable to dial remote side: %s", err)
 	}
-	c.Close()
 }
 
-func TestClientAuthWrongKeyboardInteractive(t *testing.T) {
+func TestAuthMethodWrongKeyboardInteractive(t *testing.T) {
 	answers := keyboardInteractive(map[string]string{
 		"question1": "answer1",
 		"question2": "WRONG",
 	})
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthKeyboardInteractive(&answers),
+		Auth: []AuthMethod{
+			KeyboardInteractive(answers.Challenge),
 		},
 	}
 
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err == nil {
-		c.Close()
+	if err := tryAuth(t, config); err == nil {
 		t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive")
 	}
 }
 
 // the mock server will only authenticate ssh-rsa keys
-func TestClientAuthInvalidPublicKey(t *testing.T) {
-	kc := new(keychain)
-
-	kc.add(dsaKey)
+func TestAuthMethodInvalidPublicKey(t *testing.T) {
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthKeyring(kc),
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["dsa"]),
 		},
 	}
 
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err == nil {
-		c.Close()
+	if err := tryAuth(t, config); err == nil {
 		t.Fatalf("dsa private key should not have authenticated with rsa public key")
 	}
 }
 
 // the client should authenticate with the second key
-func TestClientAuthRSAandDSA(t *testing.T) {
-	kc := new(keychain)
-	kc.add(dsaKey)
-	kc.add(rsaKey)
+func TestAuthMethodRSAandDSA(t *testing.T) {
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthKeyring(kc),
+		Auth: []AuthMethod{
+			PublicKeys(testSigners["dsa"], testSigners["rsa"]),
 		},
 	}
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err != nil {
+	if err := tryAuth(t, config); err != nil {
 		t.Fatalf("client could not authenticate with rsa key: %v", err)
 	}
-	c.Close()
 }
 
 func TestClientHMAC(t *testing.T) {
-	kc := new(keychain)
-	kc.add(rsaKey)
-	for _, mac := range DefaultMACOrder {
+	for _, mac := range supportedMACs {
 		config := &ClientConfig{
 			User: "testuser",
-			Auth: []ClientAuth{
-				ClientAuthKeyring(kc),
+			Auth: []AuthMethod{
+				PublicKeys(testSigners["rsa"]),
 			},
-			Crypto: CryptoConfig{
+			Config: Config{
 				MACs: []string{mac},
 			},
 		}
-		c, err := Dial("tcp", newMockAuthServer(t), config)
-		if err != nil {
+		if err := tryAuth(t, config); err != nil {
 			t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
 		}
-		c.Close()
 	}
 }
 
 // issue 4285.
 func TestClientUnsupportedCipher(t *testing.T) {
-	kc := new(keychain)
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthKeyring(kc),
+		Auth: []AuthMethod{
+			PublicKeys(),
 		},
-		Crypto: CryptoConfig{
+		Config: Config{
 			Ciphers: []string{"aes128-cbc"}, // not currently supported
 		},
 	}
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err == nil {
+	if err := tryAuth(t, config); err == nil {
 		t.Errorf("expected no ciphers in common")
-		c.Close()
 	}
 }
 
 func TestClientUnsupportedKex(t *testing.T) {
-	kc := new(keychain)
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthKeyring(kc),
+		Auth: []AuthMethod{
+			PublicKeys(),
 		},
-		Crypto: CryptoConfig{
+		Config: Config{
 			KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
 		},
 	}
-	c, err := Dial("tcp", newMockAuthServer(t), config)
-	if err == nil || !strings.Contains(err.Error(), "no common algorithms") {
+	if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "no common algorithms") {
 		t.Errorf("got %v, expected 'no common algorithms'", err)
 	}
-	if c != nil {
-		c.Close()
+}
+
+func TestClientLoginCert(t *testing.T) {
+	cert := &Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: CertTimeInfinity,
+		CertType:    UserCert,
+	}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	certSigner, err := NewCertSigner(cert, testSigners["rsa"])
+	if err != nil {
+		t.Fatalf("NewCertSigner: %v", err)
+	}
+
+	clientConfig := &ClientConfig{
+		User: "user",
+	}
+	clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
+
+	t.Log("should succeed")
+	if err := tryAuth(t, clientConfig); err != nil {
+		t.Errorf("cert login failed: %v", err)
+	}
+
+	t.Log("corrupted signature")
+	cert.Signature.Blob[0]++
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Errorf("cert login passed with corrupted sig")
+	}
+
+	t.Log("revoked")
+	cert.Serial = 666
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Errorf("revoked cert login succeeded")
+	}
+	cert.Serial = 1
+
+	t.Log("sign with wrong key")
+	cert.SignCert(rand.Reader, testSigners["dsa"])
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Errorf("cert login passed with non-authoritive key")
+	}
+
+	t.Log("host cert")
+	cert.CertType = HostCert
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Errorf("cert login passed with wrong type")
+	}
+	cert.CertType = UserCert
+
+	t.Log("principal specified")
+	cert.ValidPrincipals = []string{"user"}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	if err := tryAuth(t, clientConfig); err != nil {
+		t.Errorf("cert login failed: %v", err)
+	}
+
+	t.Log("wrong principal specified")
+	cert.ValidPrincipals = []string{"fred"}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Errorf("cert login passed with wrong principal")
+	}
+	cert.ValidPrincipals = nil
+
+	t.Log("added critical option")
+	cert.CriticalOptions = map[string]string{"root-access": "yes"}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Errorf("cert login passed with unrecognized critical option")
+	}
+
+	t.Log("allowed source address")
+	cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	if err := tryAuth(t, clientConfig); err != nil {
+		t.Errorf("cert login with source-address failed: %v", err)
+	}
+
+	t.Log("disallowed source address")
+	cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"}
+	cert.SignCert(rand.Reader, testSigners["ecdsa"])
+	if err := tryAuth(t, clientConfig); err == nil {
+		t.Errorf("cert login with source-address succeeded")
 	}
 }
diff --git a/ssh/client_test.go b/ssh/client_test.go
index f6c11b9..1fe790c 100644
--- a/ssh/client_test.go
+++ b/ssh/client_test.go
@@ -1,3 +1,7 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
 package ssh
 
 import (
@@ -7,6 +11,7 @@
 
 func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
 	clientConn, serverConn := net.Pipe()
+	defer clientConn.Close()
 	receivedVersion := make(chan string, 1)
 	go func() {
 		version, err := readVersion(serverConn)
@@ -17,7 +22,7 @@
 		}
 		serverConn.Close()
 	}()
-	Client(clientConn, config)
+	NewClientConn(clientConn, "", config)
 	actual := <-receivedVersion
 	if actual != expected {
 		t.Fatalf("got %s; want %s", actual, expected)
diff --git a/ssh/common.go b/ssh/common.go
index 4870e56..2fd7fd9 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -6,7 +6,9 @@
 
 import (
 	"crypto"
+	"crypto/rand"
 	"fmt"
+	"io"
 	"sync"
 
 	_ "crypto/sha1"
@@ -21,16 +23,39 @@
 	serviceSSH      = "ssh-connection"
 )
 
+// supportedCiphers specifies the supported ciphers in preference order.
+var supportedCiphers = []string{
+	"aes128-ctr", "aes192-ctr", "aes256-ctr",
+	"aes128-gcm@openssh.com",
+	"arcfour256", "arcfour128",
+}
+
+// supportedKexAlgos specifies the supported key-exchange algorithms in
+// preference order.
 var supportedKexAlgos = []string{
+	// P384 and P521 are not constant-time yet, but since we don't
+	// reuse ephemeral keys, using them for ECDH should be OK.
 	kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
 	kexAlgoDH14SHA1, kexAlgoDH1SHA1,
 }
 
+// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
+// of authenticating servers) in preference order.
 var supportedHostKeyAlgos = []string{
+	CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
+	CertAlgoECDSA384v01, CertAlgoECDSA521v01,
+
 	KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
 	KeyAlgoRSA, KeyAlgoDSA,
 }
 
+// supportedMACs specifies a default set of MAC algorithms in preference order.
+// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
+// because they have reached the end of their useful life.
+var supportedMACs = []string{
+	"hmac-sha1", "hmac-sha1-96",
+}
+
 var supportedCompressions = []string{compressionNone}
 
 // hashFuncs keeps the mapping of supported algorithms to their respective
@@ -48,23 +73,15 @@
 	CertAlgoECDSA521v01: crypto.SHA512,
 }
 
-// UnexpectedMessageError results when the SSH message that we received didn't
+// unexpectedMessageError results when the SSH message that we received didn't
 // match what we wanted.
-type UnexpectedMessageError struct {
-	expected, got uint8
+func unexpectedMessageError(expected, got uint8) error {
+	return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected)
 }
 
-func (u UnexpectedMessageError) Error() string {
-	return fmt.Sprintf("ssh: unexpected message type %d (expected %d)", u.got, u.expected)
-}
-
-// ParseError results from a malformed SSH message.
-type ParseError struct {
-	msgType uint8
-}
-
-func (p ParseError) Error() string {
-	return fmt.Sprintf("ssh: parse error in message type %d", p.msgType)
+// parseError results from a malformed SSH message.
+func parseError(tag uint8) error {
+	return fmt.Errorf("ssh: parse error in message type %d", tag)
 }
 
 func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
@@ -90,15 +107,17 @@
 	return
 }
 
+type directionAlgorithms struct {
+	Cipher      string
+	MAC         string
+	Compression string
+}
+
 type algorithms struct {
-	kex          string
-	hostKey      string
-	wCipher      string
-	rCipher      string
-	rMAC         string
-	wMAC         string
-	rCompression string
-	wCompression string
+	kex     string
+	hostKey string
+	w       directionAlgorithms
+	r       directionAlgorithms
 }
 
 func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
@@ -114,32 +133,32 @@
 		return
 	}
 
-	result.wCipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+	result.w.Cipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
 	if !ok {
 		return
 	}
 
-	result.rCipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+	result.r.Cipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
 	if !ok {
 		return
 	}
 
-	result.wMAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+	result.w.MAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
 	if !ok {
 		return
 	}
 
-	result.rMAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+	result.r.MAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
 	if !ok {
 		return
 	}
 
-	result.wCompression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+	result.w.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
 	if !ok {
 		return
 	}
 
-	result.rCompression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+	result.r.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
 	if !ok {
 		return
 	}
@@ -147,133 +166,87 @@
 	return result
 }
 
-// Cryptographic configuration common to both ServerConfig and ClientConfig.
-type CryptoConfig struct {
+// If rekeythreshold is too small, we can't make any progress sending
+// stuff.
+const minRekeyThreshold uint64 = 256
+
+// Config contains configuration data common to both ServerConfig and
+// ClientConfig.
+type Config struct {
+	// Rand provides the source of entropy for cryptographic
+	// primitives. If Rand is nil, the cryptographic random reader
+	// in package crypto/rand will be used.
+	Rand io.Reader
+
+	// The maximum number of bytes sent or received after which a
+	// new key is negotiated. It must be at least 256. If
+	// unspecified, 1 gigabyte is used.
+	RekeyThreshold uint64
+
 	// The allowed key exchanges algorithms. If unspecified then a
 	// default set of algorithms is used.
 	KeyExchanges []string
 
-	// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
-	// used.
+	// The allowed cipher algorithms. If unspecified then a sensible
+	// default is used.
 	Ciphers []string
 
-	// The allowed MAC algorithms. If unspecified then DefaultMACOrder is used.
+	// The allowed MAC algorithms. If unspecified then a sensible default
+	// is used.
 	MACs []string
 }
 
-func (c *CryptoConfig) ciphers() []string {
+// SetDefaults sets sensible values for unset fields in config. This is
+// exported for testing: Configs passed to SSH functions are copied and have
+// default values set automatically.
+func (c *Config) SetDefaults() {
+	if c.Rand == nil {
+		c.Rand = rand.Reader
+	}
 	if c.Ciphers == nil {
-		return DefaultCipherOrder
+		c.Ciphers = supportedCiphers
 	}
-	return c.Ciphers
-}
 
-func (c *CryptoConfig) kexes() []string {
 	if c.KeyExchanges == nil {
-		return defaultKeyExchangeOrder
+		c.KeyExchanges = supportedKexAlgos
 	}
-	return c.KeyExchanges
-}
 
-func (c *CryptoConfig) macs() []string {
 	if c.MACs == nil {
-		return DefaultMACOrder
+		c.MACs = supportedMACs
 	}
-	return c.MACs
-}
 
-// serialize a signed slice according to RFC 4254 6.6. The name should
-// be a key type name, rather than a cert type name.
-func serializeSignature(name string, sig []byte) []byte {
-	length := stringLength(len(name))
-	length += stringLength(len(sig))
-
-	ret := make([]byte, length)
-	r := marshalString(ret, []byte(name))
-	r = marshalString(r, sig)
-
-	return ret
-}
-
-// MarshalPublicKey serializes a supported key or certificate for use
-// by the SSH wire protocol. It can be used for comparison with the
-// pubkey argument of ServerConfig's PublicKeyCallback as well as for
-// generating an authorized_keys or host_keys file.
-func MarshalPublicKey(key PublicKey) []byte {
-	// See also RFC 4253 6.6.
-	algoname := key.PublicKeyAlgo()
-	blob := key.Marshal()
-
-	length := stringLength(len(algoname))
-	length += len(blob)
-	ret := make([]byte, length)
-	r := marshalString(ret, []byte(algoname))
-	copy(r, blob)
-	return ret
-}
-
-// pubAlgoToPrivAlgo returns the private key algorithm format name that
-// corresponds to a given public key algorithm format name.  For most
-// public keys, the private key algorithm name is the same.  For some
-// situations, such as openssh certificates, the private key algorithm and
-// public key algorithm names differ.  This accounts for those situations.
-func pubAlgoToPrivAlgo(pubAlgo string) string {
-	switch pubAlgo {
-	case CertAlgoRSAv01:
-		return KeyAlgoRSA
-	case CertAlgoDSAv01:
-		return KeyAlgoDSA
-	case CertAlgoECDSA256v01:
-		return KeyAlgoECDSA256
-	case CertAlgoECDSA384v01:
-		return KeyAlgoECDSA384
-	case CertAlgoECDSA521v01:
-		return KeyAlgoECDSA521
+	if c.RekeyThreshold == 0 {
+		// RFC 4253, section 9 suggests rekeying after 1G.
+		c.RekeyThreshold = 1 << 30
 	}
-	return pubAlgo
+	if c.RekeyThreshold < minRekeyThreshold {
+		c.RekeyThreshold = minRekeyThreshold
+	}
 }
 
 // buildDataSignedForAuth returns the data that is signed in order to prove
 // possession of a private key. See RFC 4252, section 7.
 func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
-	user := []byte(req.User)
-	service := []byte(req.Service)
-	method := []byte(req.Method)
-
-	length := stringLength(len(sessionId))
-	length += 1
-	length += stringLength(len(user))
-	length += stringLength(len(service))
-	length += stringLength(len(method))
-	length += 1
-	length += stringLength(len(algo))
-	length += stringLength(len(pubKey))
-
-	ret := make([]byte, length)
-	r := marshalString(ret, sessionId)
-	r[0] = msgUserAuthRequest
-	r = r[1:]
-	r = marshalString(r, user)
-	r = marshalString(r, service)
-	r = marshalString(r, method)
-	r[0] = 1
-	r = r[1:]
-	r = marshalString(r, algo)
-	r = marshalString(r, pubKey)
-	return ret
-}
-
-// safeString sanitises s according to RFC 4251, section 9.2.
-// All control characters except tab, carriage return and newline are
-// replaced by 0x20.
-func safeString(s string) string {
-	out := []byte(s)
-	for i, c := range out {
-		if c < 0x20 && c != 0xd && c != 0xa && c != 0x9 {
-			out[i] = 0x20
-		}
+	data := struct {
+		Session []byte
+		Type    byte
+		User    string
+		Service string
+		Method  string
+		Sign    bool
+		Algo    []byte
+		PubKey  []byte
+	}{
+		sessionId,
+		msgUserAuthRequest,
+		req.User,
+		req.Service,
+		req.Method,
+		true,
+		algo,
+		pubKey,
 	}
-	return string(out)
+	return Marshal(data)
 }
 
 func appendU16(buf []byte, n uint16) []byte {
@@ -284,6 +257,12 @@
 	return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
 }
 
+func appendU64(buf []byte, n uint64) []byte {
+	return append(buf,
+		byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32),
+		byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
+}
+
 func appendInt(buf []byte, n int) []byte {
 	return appendU32(buf, uint32(n))
 }
@@ -296,11 +275,9 @@
 
 func appendBool(buf []byte, b bool) []byte {
 	if b {
-		buf = append(buf, 1)
-	} else {
-		buf = append(buf, 0)
+		return append(buf, 1)
 	}
-	return buf
+	return append(buf, 0)
 }
 
 // newCond is a helper to hide the fact that there is no usable zero
@@ -311,7 +288,9 @@
 // wishing to write to a channel.
 type window struct {
 	*sync.Cond
-	win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
+	win          uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
+	writeWaiters int
+	closed       bool
 }
 
 // add adds win to the amount of window available
@@ -335,18 +314,44 @@
 	return true
 }
 
+// close sets the window to closed, so all reservations fail
+// immediately.
+func (w *window) close() {
+	w.L.Lock()
+	w.closed = true
+	w.Broadcast()
+	w.L.Unlock()
+}
+
 // reserve reserves win from the available window capacity.
 // If no capacity remains, reserve will block. reserve may
 // return less than requested.
-func (w *window) reserve(win uint32) uint32 {
+func (w *window) reserve(win uint32) (uint32, error) {
+	var err error
 	w.L.Lock()
-	for w.win == 0 {
+	w.writeWaiters++
+	w.Broadcast()
+	for w.win == 0 && !w.closed {
 		w.Wait()
 	}
+	w.writeWaiters--
 	if w.win < win {
 		win = w.win
 	}
 	w.win -= win
+	if w.closed {
+		err = io.EOF
+	}
 	w.L.Unlock()
-	return win
+	return win, err
+}
+
+// waitWriterBlocked waits until some goroutine is blocked for further
+// writes. It is used in tests only.
+func (w *window) waitWriterBlocked() {
+	w.Cond.L.Lock()
+	for w.writeWaiters == 0 {
+		w.Cond.Wait()
+	}
+	w.Cond.L.Unlock()
 }
diff --git a/ssh/connection.go b/ssh/connection.go
new file mode 100644
index 0000000..93551e2
--- /dev/null
+++ b/ssh/connection.go
@@ -0,0 +1,144 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"fmt"
+	"net"
+)
+
+// OpenChannelError is returned if the other side rejects an
+// OpenChannel request.
+type OpenChannelError struct {
+	Reason  RejectionReason
+	Message string
+}
+
+func (e *OpenChannelError) Error() string {
+	return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
+}
+
+// ConnMetadata holds metadata for the connection.
+type ConnMetadata interface {
+	// User returns the user ID for this connection.
+	// It is empty if no authentication is used.
+	User() string
+
+	// SessionID returns the sesson hash, also denoted by H.
+	SessionID() []byte
+
+	// ClientVersion returns the client's version string as hashed
+	// into the session ID.
+	ClientVersion() []byte
+
+	// ServerVersion returns the client's version string as hashed
+	// into the session ID.
+	ServerVersion() []byte
+
+	// RemoteAddr returns the remote address for this connection.
+	RemoteAddr() net.Addr
+
+	// LocalAddr returns the local address for this connection.
+	LocalAddr() net.Addr
+}
+
+// Conn represents an SSH connection for both server and client roles.
+// Conn is the basis for implementing an application layer, such
+// as ClientConn, which implements the traditional shell access for
+// clients.
+type Conn interface {
+	ConnMetadata
+
+	// SendRequest sends a global request, and returns the
+	// reply. If wantReply is true, it returns the response status
+	// and payload. See also RFC4254, section 4.
+	SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
+
+	// OpenChannel tries to open an channel. If the request is
+	// rejected, it returns *OpenChannelError. On success it returns
+	// the SSH Channel and a Go channel for incoming, out-of-band
+	// requests. The Go channel must be serviced, or the
+	// connection will hang.
+	OpenChannel(name string, data []byte) (Channel, <-chan *Request, error)
+
+	// Close closes the underlying network connection
+	Close() error
+
+	// Wait blocks until the connection has shut down, and returns the
+	// error causing the shutdown.
+	Wait() error
+
+	// TODO(hanwen): consider exposing:
+	//   RequestKeyChange
+	//   Disconnect
+}
+
+// DiscardRequests consumes and rejects all requests from the
+// passed-in channel.
+func DiscardRequests(in <-chan *Request) {
+	for req := range in {
+		if req.WantReply {
+			req.Reply(false, nil)
+		}
+	}
+}
+
+// A connection represents an incoming connection.
+type connection struct {
+	transport *handshakeTransport
+	sshConn
+
+	// The connection protocol.
+	*mux
+}
+
+func (c *connection) Close() error {
+	return c.sshConn.conn.Close()
+}
+
+// sshconn provides net.Conn metadata, but disallows direct reads and
+// writes.
+type sshConn struct {
+	conn net.Conn
+
+	user          string
+	sessionID     []byte
+	clientVersion []byte
+	serverVersion []byte
+}
+
+func dup(src []byte) []byte {
+	dst := make([]byte, len(src))
+	copy(dst, src)
+	return dst
+}
+
+func (c *sshConn) User() string {
+	return c.user
+}
+
+func (c *sshConn) RemoteAddr() net.Addr {
+	return c.conn.RemoteAddr()
+}
+
+func (c *sshConn) Close() error {
+	return c.conn.Close()
+}
+
+func (c *sshConn) LocalAddr() net.Addr {
+	return c.conn.LocalAddr()
+}
+
+func (c *sshConn) SessionID() []byte {
+	return dup(c.sessionID)
+}
+
+func (c *sshConn) ClientVersion() []byte {
+	return dup(c.clientVersion)
+}
+
+func (c *sshConn) ServerVersion() []byte {
+	return dup(c.serverVersion)
+}
diff --git a/ssh/doc.go b/ssh/doc.go
index 22ff338..d4d16f0 100644
--- a/ssh/doc.go
+++ b/ssh/doc.go
@@ -13,7 +13,6 @@
 
 References:
   [PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys
-  [PROTOCOL.agent]:    http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent
   [SSH-PARAMETERS]:    http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
 */
 package ssh
diff --git a/ssh/example_test.go b/ssh/example_test.go
index a88a677..d9d6a54 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -9,17 +9,23 @@
 	"fmt"
 	"io/ioutil"
 	"log"
+	"net"
 	"net/http"
 
 	"code.google.com/p/go.crypto/ssh/terminal"
 )
 
-func ExampleListen() {
+func ExampleNewServerConn() {
 	// An SSH server is represented by a ServerConfig, which holds
 	// certificate details and handles authentication of ServerConns.
 	config := &ServerConfig{
-		PasswordCallback: func(conn *ServerConn, user, pass string) bool {
-			return user == "testuser" && pass == "tiger"
+		PasswordCallback: func(c ConnMetadata, pass []byte) (*Permissions, error) {
+			// Should use constant-time compare (or better, salt+hash) in
+			// a production setting.
+			if c.User() == "testuser" && string(pass) == "tiger" {
+				return nil, nil
+			}
+			return nil, fmt.Errorf("password rejected for %q", c.User())
 		},
 	}
 
@@ -37,50 +43,65 @@
 
 	// Once a ServerConfig has been configured, connections can be
 	// accepted.
-	listener, err := Listen("tcp", "0.0.0.0:2022", config)
+	listener, err := net.Listen("tcp", "0.0.0.0:2022")
 	if err != nil {
 		panic("failed to listen for connection")
 	}
-	sConn, err := listener.Accept()
+	nConn, err := listener.Accept()
 	if err != nil {
 		panic("failed to accept incoming connection")
 	}
-	if err := sConn.Handshake(); err != nil {
+
+	// Before use, a handshake must be performed on the incoming
+	// net.Conn.
+	_, chans, reqs, err := NewServerConn(nConn, config)
+	if err != nil {
 		panic("failed to handshake")
 	}
+	// The incoming Request channel must be serviced.
+	go DiscardRequests(reqs)
 
-	// A ServerConn multiplexes several channels, which must
-	// themselves be Accepted.
-	for {
-		// Accept reads from the connection, demultiplexes packets
-		// to their corresponding channels and returns when a new
-		// channel request is seen. Some goroutine must always be
-		// calling Accept; otherwise no messages will be forwarded
-		// to the channels.
-		channel, err := sConn.Accept()
-		if err != nil {
-			panic("error from Accept")
-		}
-
+	// Service the incoming Channel channel.
+	for newChannel := range chans {
 		// Channels have a type, depending on the application level
 		// protocol intended. In the case of a shell, the type is
 		// "session" and ServerShell may be used to present a simple
 		// terminal interface.
-		if channel.ChannelType() != "session" {
-			channel.Reject(UnknownChannelType, "unknown channel type")
+		if newChannel.ChannelType() != "session" {
+			newChannel.Reject(UnknownChannelType, "unknown channel type")
 			continue
 		}
-		channel.Accept()
+		channel, requests, err := newChannel.Accept()
+		if err != nil {
+			panic("could not accept channel.")
+		}
+
+		// Sessions have out-of-band requests such as "shell",
+		// "pty-req" and "env".  Here we handle only the
+		// "shell" request.
+		go func(in <-chan *Request) {
+			for req := range in {
+				ok := false
+				switch req.Type {
+				case "shell":
+					ok = true
+					if len(req.Payload) > 0 {
+						// We don't accept any
+						// commands, only the
+						// default shell.
+						ok = false
+					}
+				}
+				req.Reply(ok, nil)
+			}
+		}(requests)
 
 		term := terminal.NewTerminal(channel, "> ")
-		serverTerm := &ServerTerminal{
-			Term:    term,
-			Channel: channel,
-		}
+
 		go func() {
 			defer channel.Close()
 			for {
-				line, err := serverTerm.ReadLine()
+				line, err := term.ReadLine()
 				if err != nil {
 					break
 				}
@@ -95,13 +116,11 @@
 	// the "password" authentication method is supported.
 	//
 	// To authenticate with the remote server you must pass at least one
-	// implementation of ClientAuth via the Auth field in ClientConfig.
+	// implementation of AuthMethod via the Auth field in ClientConfig.
 	config := &ClientConfig{
 		User: "username",
-		Auth: []ClientAuth{
-			// ClientAuthPassword wraps a ClientPassword implementation
-			// in a type that implements ClientAuth.
-			ClientAuthPassword(password("yourpassword")),
+		Auth: []AuthMethod{
+			Password("yourpassword"),
 		},
 	}
 	client, err := Dial("tcp", "yourserver.com:22", config)
@@ -127,11 +146,11 @@
 	fmt.Println(b.String())
 }
 
-func ExampleClientConn_Listen() {
+func ExampleClient_Listen() {
 	config := &ClientConfig{
 		User: "username",
-		Auth: []ClientAuth{
-			ClientAuthPassword(password("password")),
+		Auth: []AuthMethod{
+			Password("password"),
 		},
 	}
 	// Dial your ssh server.
@@ -158,8 +177,8 @@
 	// Create client config
 	config := &ClientConfig{
 		User: "username",
-		Auth: []ClientAuth{
-			ClientAuthPassword(password("password")),
+		Auth: []AuthMethod{
+			Password("password"),
 		},
 	}
 	// Connect to ssh server
diff --git a/ssh/handshake.go b/ssh/handshake.go
new file mode 100644
index 0000000..a1e2c23
--- /dev/null
+++ b/ssh/handshake.go
@@ -0,0 +1,393 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"crypto/rand"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net"
+	"sync"
+)
+
+// debugHandshake, if set, prints messages sent and received.  Key
+// exchange messages are printed as if DH were used, so the debug
+// messages are wrong when using ECDH.
+const debugHandshake = false
+
+// keyingTransport is a packet based transport that supports key
+// changes. It need not be thread-safe. It should pass through
+// msgNewKeys in both directions.
+type keyingTransport interface {
+	packetConn
+
+	// prepareKeyChange sets up a key change. The key change for a
+	// direction will be effected if a msgNewKeys message is sent
+	// or received.
+	prepareKeyChange(*algorithms, *kexResult) error
+
+	// getSessionID returns the session ID. prepareKeyChange must
+	// have been called once.
+	getSessionID() []byte
+}
+
+// rekeyingTransport is the interface of handshakeTransport that we
+// (internally) expose to ClientConn and ServerConn.
+type rekeyingTransport interface {
+	packetConn
+
+	// requestKeyChange asks the remote side to change keys. All
+	// writes are blocked until the key change succeeds, which is
+	// signaled by reading a msgNewKeys.
+	requestKeyChange() error
+
+	// getSessionID returns the session ID. This is only valid
+	// after the first key change has completed.
+	getSessionID() []byte
+}
+
+// handshakeTransport implements rekeying on top of a keyingTransport
+// and offers a thread-safe writePacket() interface.
+type handshakeTransport struct {
+	conn   keyingTransport
+	config *Config
+
+	serverVersion []byte
+	clientVersion []byte
+
+	hostKeys []Signer // If hostKeys are given, we are the server.
+
+	// On read error, incoming is closed, and readError is set.
+	incoming  chan []byte
+	readError error
+
+	// data for host key checking
+	hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+	dialAddress     string
+	remoteAddr      net.Addr
+
+	readSinceKex uint64
+
+	// Protects the writing side of the connection
+	mu              sync.Mutex
+	cond            *sync.Cond
+	sentInitPacket  []byte
+	sentInitMsg     *kexInitMsg
+	writtenSinceKex uint64
+	writeError      error
+}
+
+func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
+	t := &handshakeTransport{
+		conn:          conn,
+		serverVersion: serverVersion,
+		clientVersion: clientVersion,
+		incoming:      make(chan []byte, 16),
+		config:        config,
+	}
+	t.cond = sync.NewCond(&t.mu)
+	return t
+}
+
+func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
+	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
+	t.dialAddress = dialAddr
+	t.remoteAddr = addr
+	t.hostKeyCallback = config.HostKeyCallback
+	go t.readLoop()
+	return t
+}
+
+func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
+	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
+	t.hostKeys = config.hostKeys
+	go t.readLoop()
+	return t
+}
+
+func (t *handshakeTransport) getSessionID() []byte {
+	return t.conn.getSessionID()
+}
+
+func (t *handshakeTransport) id() string {
+	if len(t.hostKeys) > 0 {
+		return "server"
+	}
+	return "client"
+}
+
+func (t *handshakeTransport) readPacket() ([]byte, error) {
+	p, ok := <-t.incoming
+	if !ok {
+		return nil, t.readError
+	}
+	return p, nil
+}
+
+func (t *handshakeTransport) readLoop() {
+	for {
+		p, err := t.readOnePacket()
+		if err != nil {
+			t.readError = err
+			close(t.incoming)
+			break
+		}
+		if p[0] == msgIgnore || p[0] == msgDebug {
+			continue
+		}
+		t.incoming <- p
+	}
+}
+
+func (t *handshakeTransport) readOnePacket() ([]byte, error) {
+	if t.readSinceKex > t.config.RekeyThreshold {
+		if err := t.requestKeyChange(); err != nil {
+			return nil, err
+		}
+	}
+
+	p, err := t.conn.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	t.readSinceKex += uint64(len(p))
+	if debugHandshake {
+		msg, err := decode(p)
+		log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
+	}
+	if p[0] != msgKexInit {
+		return p, nil
+	}
+	err = t.enterKeyExchange(p)
+
+	t.mu.Lock()
+	if err != nil {
+		// drop connection
+		t.conn.Close()
+		t.writeError = err
+	}
+
+	if debugHandshake {
+		log.Printf("%s exited key exchange, err %v", t.id(), err)
+	}
+
+	// Unblock writers.
+	t.sentInitMsg = nil
+	t.sentInitPacket = nil
+	t.cond.Broadcast()
+	t.writtenSinceKex = 0
+	t.mu.Unlock()
+
+	if err != nil {
+		return nil, err
+	}
+
+	t.readSinceKex = 0
+	return []byte{msgNewKeys}, nil
+}
+
+// sendKexInit sends a key change message, and returns the message
+// that was sent. After initiating the key change, all writes will be
+// blocked until the change is done, and a failed key change will
+// close the underlying transport. This function is safe for
+// concurrent use by multiple goroutines.
+func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.sendKexInitLocked()
+}
+
+func (t *handshakeTransport) requestKeyChange() error {
+	_, _, err := t.sendKexInit()
+	return err
+}
+
+// sendKexInitLocked sends a key change message. t.mu must be locked
+// while this happens.
+func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
+	// kexInits may be sent either in response to the other side,
+	// or because our side wants to initiate a key change, so we
+	// may have already sent a kexInit. In that case, don't send a
+	// second kexInit.
+	if t.sentInitMsg != nil {
+		return t.sentInitMsg, t.sentInitPacket, nil
+	}
+	msg := &kexInitMsg{
+		KexAlgos:                t.config.KeyExchanges,
+		CiphersClientServer:     t.config.Ciphers,
+		CiphersServerClient:     t.config.Ciphers,
+		MACsClientServer:        t.config.MACs,
+		MACsServerClient:        t.config.MACs,
+		CompressionClientServer: supportedCompressions,
+		CompressionServerClient: supportedCompressions,
+	}
+	io.ReadFull(rand.Reader, msg.Cookie[:])
+
+	if len(t.hostKeys) > 0 {
+		for _, k := range t.hostKeys {
+			msg.ServerHostKeyAlgos = append(
+				msg.ServerHostKeyAlgos, k.PublicKey().Type())
+		}
+	} else {
+		msg.ServerHostKeyAlgos = supportedHostKeyAlgos
+	}
+	packet := Marshal(msg)
+
+	// writePacket destroys the contents, so save a copy.
+	packetCopy := make([]byte, len(packet))
+	copy(packetCopy, packet)
+
+	if err := t.conn.writePacket(packetCopy); err != nil {
+		return nil, nil, err
+	}
+
+	t.sentInitMsg = msg
+	t.sentInitPacket = packet
+	return msg, packet, nil
+}
+
+func (t *handshakeTransport) writePacket(p []byte) error {
+	t.mu.Lock()
+	if t.writtenSinceKex > t.config.RekeyThreshold {
+		t.sendKexInitLocked()
+	}
+	for t.sentInitMsg != nil {
+		t.cond.Wait()
+	}
+	if t.writeError != nil {
+		return t.writeError
+	}
+	t.writtenSinceKex += uint64(len(p))
+
+	var err error
+	switch p[0] {
+	case msgKexInit:
+		err = errors.New("ssh: only handshakeTransport can send kexInit")
+	case msgNewKeys:
+		err = errors.New("ssh: only handshakeTransport can send newKeys")
+	default:
+		err = t.conn.writePacket(p)
+	}
+	t.mu.Unlock()
+	return err
+}
+
+func (t *handshakeTransport) Close() error {
+	return t.conn.Close()
+}
+
+// enterKeyExchange runs the key exchange.
+func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
+	if debugHandshake {
+		log.Printf("%s entered key exchange", t.id())
+	}
+	myInit, myInitPacket, err := t.sendKexInit()
+	if err != nil {
+		return err
+	}
+
+	otherInit := &kexInitMsg{}
+	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
+		return err
+	}
+
+	magics := handshakeMagics{
+		clientVersion: t.clientVersion,
+		serverVersion: t.serverVersion,
+		clientKexInit: otherInitPacket,
+		serverKexInit: myInitPacket,
+	}
+
+	clientInit := otherInit
+	serverInit := myInit
+	if len(t.hostKeys) == 0 {
+		clientInit = myInit
+		serverInit = otherInit
+
+		magics.clientKexInit = myInitPacket
+		magics.serverKexInit = otherInitPacket
+	}
+
+	algs := findAgreedAlgorithms(clientInit, serverInit)
+	if algs == nil {
+		return errors.New("ssh: no common algorithms")
+	}
+
+	// We don't send FirstKexFollows, but we handle receiving it.
+	if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] {
+		// other side sent a kex message for the wrong algorithm,
+		// which we have to ignore.
+		if _, err := t.conn.readPacket(); err != nil {
+			return err
+		}
+	}
+
+	kex, ok := kexAlgoMap[algs.kex]
+	if !ok {
+		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
+	}
+
+	var result *kexResult
+	if len(t.hostKeys) > 0 {
+		result, err = t.server(kex, algs, &magics)
+	} else {
+		result, err = t.client(kex, algs, &magics)
+	}
+
+	if err != nil {
+		return err
+	}
+
+	t.conn.prepareKeyChange(algs, result)
+	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
+		return err
+	}
+	if packet, err := t.conn.readPacket(); err != nil {
+		return err
+	} else if packet[0] != msgNewKeys {
+		return unexpectedMessageError(msgNewKeys, packet[0])
+	}
+	return nil
+}
+
+func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
+	var hostKey Signer
+	for _, k := range t.hostKeys {
+		if algs.hostKey == k.PublicKey().Type() {
+			hostKey = k
+		}
+	}
+
+	r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
+	return r, err
+}
+
+func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
+	result, err := kex.Client(t.conn, t.config.Rand, magics)
+	if err != nil {
+		return nil, err
+	}
+
+	hostKey, err := ParsePublicKey(result.HostKey)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := verifyHostKeySignature(hostKey, result); err != nil {
+		return nil, err
+	}
+
+	if t.hostKeyCallback != nil {
+		err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return result, nil
+}
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
new file mode 100644
index 0000000..613c498
--- /dev/null
+++ b/ssh/handshake_test.go
@@ -0,0 +1,311 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"bytes"
+	"crypto/rand"
+	"fmt"
+	"net"
+	"testing"
+)
+
+type testChecker struct {
+	calls []string
+}
+
+func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+	if dialAddr == "bad" {
+		return fmt.Errorf("dialAddr is bad")
+	}
+
+	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
+		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
+	}
+
+	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
+
+	return nil
+}
+
+// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
+// therefore is buffered (net.Pipe deadlocks if both sides start with
+// a write.)
+func netPipe() (net.Conn, net.Conn, error) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return nil, nil, err
+	}
+	defer listener.Close()
+	c1, err := net.Dial("tcp", listener.Addr().String())
+	if err != nil {
+		return nil, nil, err
+	}
+
+	c2, err := listener.Accept()
+	if err != nil {
+		c1.Close()
+		return nil, nil, err
+	}
+
+	return c1, c2, nil
+}
+
+func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
+	a, b, err := netPipe()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	trC := newTransport(a, rand.Reader, true)
+	trS := newTransport(b, rand.Reader, false)
+	clientConf.SetDefaults()
+
+	v := []byte("version")
+	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
+
+	serverConf := &ServerConfig{}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.SetDefaults()
+	server = newServerTransport(trS, v, v, serverConf)
+
+	return client, server, nil
+}
+
+func TestHandshakeBasic(t *testing.T) {
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+
+	defer trC.Close()
+	defer trS.Close()
+
+	go func() {
+		// Client writes a bunch of stuff, and does a key
+		// change in the middle. This should not confuse the
+		// handshake in progress
+		for i := 0; i < 10; i++ {
+			p := []byte{msgRequestSuccess, byte(i)}
+			if err := trC.writePacket(p); err != nil {
+				t.Fatalf("sendPacket: %v", err)
+			}
+			if i == 5 {
+				// halfway through, we request a key change.
+				_, _, err := trC.sendKexInit()
+				if err != nil {
+					t.Fatalf("sendKexInit: %v", err)
+				}
+			}
+		}
+		trC.Close()
+	}()
+
+	// Server checks that client messages come in cleanly
+	i := 0
+	for {
+		p, err := trS.readPacket()
+		if err != nil {
+			break
+		}
+		if p[0] == msgNewKeys {
+			continue
+		}
+		want := []byte{msgRequestSuccess, byte(i)}
+		if bytes.Compare(p, want) != 0 {
+			t.Errorf("message %d: got %q, want %q", i, p, want)
+		}
+		i++
+	}
+	if i != 10 {
+		t.Errorf("received %d messages, want 10.", i)
+	}
+
+	// If all went well, we registered exactly 1 key change.
+	if len(checker.calls) != 1 {
+		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
+	}
+
+	pub := testSigners["ecdsa"].PublicKey()
+	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
+	if want != checker.calls[0] {
+		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
+	}
+}
+
+func TestHandshakeError(t *testing.T) {
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	// send a packet
+	packet := []byte{msgRequestSuccess, 42}
+	if err := trC.writePacket(packet); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+
+	// Now request a key change.
+	_, _, err = trC.sendKexInit()
+	if err != nil {
+		t.Errorf("sendKexInit: %v", err)
+	}
+
+	// the key change will fail, and afterwards we can't write.
+	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
+		t.Errorf("writePacket after botched rekey succeeded.")
+	}
+
+	readback, err := trS.readPacket()
+	if err != nil {
+		t.Fatalf("server closed too soon: %v", err)
+	}
+	if bytes.Compare(readback, packet) != 0 {
+		t.Errorf("got %q want %q", readback, packet)
+	}
+	readback, err = trS.readPacket()
+	if err == nil {
+		t.Errorf("got a message %q after failed key change", readback)
+	}
+}
+
+func TestHandshakeTwice(t *testing.T) {
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+
+	defer trC.Close()
+	defer trS.Close()
+
+	// send a packet
+	packet := make([]byte, 5)
+	packet[0] = msgRequestSuccess
+	if err := trC.writePacket(packet); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+
+	// Now request a key change.
+	_, _, err = trC.sendKexInit()
+	if err != nil {
+		t.Errorf("sendKexInit: %v", err)
+	}
+
+	// Send another packet. Use a fresh one, since writePacket destroys.
+	packet = make([]byte, 5)
+	packet[0] = msgRequestSuccess
+	if err := trC.writePacket(packet); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+
+	// 2nd key change.
+	_, _, err = trC.sendKexInit()
+	if err != nil {
+		t.Errorf("sendKexInit: %v", err)
+	}
+
+	packet = make([]byte, 5)
+	packet[0] = msgRequestSuccess
+	if err := trC.writePacket(packet); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+
+	packet = make([]byte, 5)
+	packet[0] = msgRequestSuccess
+	for i := 0; i < 5; i++ {
+		msg, err := trS.readPacket()
+		if err != nil {
+			t.Fatalf("server closed too soon: %v", err)
+		}
+		if msg[0] == msgNewKeys {
+			continue
+		}
+
+		if bytes.Compare(msg, packet) != 0 {
+			t.Errorf("packet %d: got %q want %q", i, msg, packet)
+		}
+	}
+	if len(checker.calls) != 2 {
+		t.Errorf("got %d key changes, want 2", len(checker.calls))
+	}
+}
+
+func TestHandshakeAutoRekeyWrite(t *testing.T) {
+	checker := &testChecker{}
+	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+	clientConf.RekeyThreshold = 500
+	trC, trS, err := handshakePair(clientConf, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	for i := 0; i < 5; i++ {
+		packet := make([]byte, 251)
+		packet[0] = msgRequestSuccess
+		if err := trC.writePacket(packet); err != nil {
+			t.Errorf("writePacket: %v", err)
+		}
+	}
+
+	j := 0
+	for ; j < 5; j++ {
+		_, err := trS.readPacket()
+		if err != nil {
+			break
+		}
+	}
+
+	if j != 5 {
+		t.Errorf("got %d, want 5 messages", j)
+	}
+
+	if len(checker.calls) != 2 {
+		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
+	}
+}
+
+type syncChecker struct {
+	called chan int
+}
+
+func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+	t.called <- 1
+	return nil
+}
+
+func TestHandshakeAutoRekeyRead(t *testing.T) {
+	sync := &syncChecker{make(chan int, 2)}
+	clientConf := &ClientConfig{
+		HostKeyCallback: sync.Check,
+	}
+	clientConf.RekeyThreshold = 500
+
+	trC, trS, err := handshakePair(clientConf, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	packet := make([]byte, 501)
+	packet[0] = msgRequestSuccess
+	if err := trS.writePacket(packet); err != nil {
+		t.Fatalf("writePacket: %v", err)
+	}
+	// While we read out the packet, a key change will be
+	// initiated.
+	if _, err := trC.readPacket(); err != nil {
+		t.Fatalf("readPacket(client): %v", err)
+	}
+
+	<-sync.called
+}
diff --git a/ssh/kex.go b/ssh/kex.go
index d2e3b70..6a835c7 100644
--- a/ssh/kex.go
+++ b/ssh/kex.go
@@ -30,10 +30,10 @@
 	// Shared secret. See also RFC 4253, section 8.
 	K []byte
 
-	// Host key as hashed into H
+	// Host key as hashed into H.
 	HostKey []byte
 
-	// Signature of H
+	// Signature of H.
 	Signature []byte
 
 	// A cryptographic hash function that matches the security
@@ -94,7 +94,7 @@
 	kexDHInit := kexDHInitMsg{
 		X: X,
 	}
-	if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
+	if err := c.writePacket(Marshal(&kexDHInit)); err != nil {
 		return nil, err
 	}
 
@@ -104,7 +104,7 @@
 	}
 
 	var kexDHReply kexDHReplyMsg
-	if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
+	if err = Unmarshal(packet, &kexDHReply); err != nil {
 		return nil, err
 	}
 
@@ -138,7 +138,7 @@
 		return
 	}
 	var kexDHInit kexDHInitMsg
-	if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
+	if err = Unmarshal(packet, &kexDHInit); err != nil {
 		return
 	}
 
@@ -153,7 +153,7 @@
 		return nil, err
 	}
 
-	hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+	hostKeyBytes := priv.PublicKey().Marshal()
 
 	h := hashFunc.New()
 	magics.write(h)
@@ -179,7 +179,7 @@
 		Y:         Y,
 		Signature: sig,
 	}
-	packet = marshal(msgKexDHReply, kexDHReply)
+	packet = Marshal(&kexDHReply)
 
 	err = c.writePacket(packet)
 	return &kexResult{
@@ -207,7 +207,7 @@
 		ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
 	}
 
-	serialized := marshal(msgKexECDHInit, kexInit)
+	serialized := Marshal(&kexInit)
 	if err := c.writePacket(serialized); err != nil {
 		return nil, err
 	}
@@ -218,7 +218,7 @@
 	}
 
 	var reply kexECDHReplyMsg
-	if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil {
+	if err = Unmarshal(packet, &reply); err != nil {
 		return nil, err
 	}
 
@@ -297,7 +297,7 @@
 	}
 
 	var kexECDHInit kexECDHInitMsg
-	if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil {
+	if err = Unmarshal(packet, &kexECDHInit); err != nil {
 		return nil, err
 	}
 
@@ -314,7 +314,7 @@
 		return nil, err
 	}
 
-	hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+	hostKeyBytes := priv.PublicKey().Marshal()
 
 	serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
 
@@ -346,7 +346,7 @@
 		Signature:       sig,
 	}
 
-	serialized := marshal(msgKexECDHReply, reply)
+	serialized := Marshal(&reply)
 	if err := c.writePacket(serialized); err != nil {
 		return nil, err
 	}
diff --git a/ssh/kex_test.go b/ssh/kex_test.go
index 1e931a3..0db5f9b 100644
--- a/ssh/kex_test.go
+++ b/ssh/kex_test.go
@@ -29,7 +29,7 @@
 			c <- kexResultErr{r, e}
 		}()
 		go func() {
-			r, e := kex.Server(b, rand.Reader, &magics, ecdsaKey)
+			r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
 			s <- kexResultErr{r, e}
 		}()
 
diff --git a/ssh/keys.go b/ssh/keys.go
index b41fefc..e8af511 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -33,7 +33,7 @@
 
 // parsePubKey parses a public key of the given algorithm.
 // Use ParsePublicKey for keys with prepended algorithm.
-func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, ok bool) {
+func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) {
 	switch algo {
 	case KeyAlgoRSA:
 		return parseRSA(in)
@@ -42,15 +42,19 @@
 	case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
 		return parseECDSA(in)
 	case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
-		return parseOpenSSHCertV01(in, algo)
+		cert, err := parseCert(in, certToPrivAlgo(algo))
+		if err != nil {
+			return nil, nil, err
+		}
+		return cert, nil, nil
 	}
-	return nil, nil, false
+	return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err)
 }
 
 // parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
 // (see sshd(8) manual page) once the options and key type fields have been
 // removed.
-func parseAuthorizedKey(in []byte) (out PublicKey, comment string, ok bool) {
+func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) {
 	in = bytes.TrimSpace(in)
 
 	i := bytes.IndexAny(in, " \t")
@@ -62,20 +66,20 @@
 	key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key)))
 	n, err := base64.StdEncoding.Decode(key, base64Key)
 	if err != nil {
-		return
+		return nil, "", err
 	}
 	key = key[:n]
-	out, _, ok = ParsePublicKey(key)
-	if !ok {
-		return nil, "", false
+	out, err = ParsePublicKey(key)
+	if err != nil {
+		return nil, "", err
 	}
 	comment = string(bytes.TrimSpace(in[i:]))
-	return
+	return out, comment, nil
 }
 
 // ParseAuthorizedKeys parses a public key from an authorized_keys
 // file used in OpenSSH according to the sshd(8) manual page.
-func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, ok bool) {
+func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
 	for len(in) > 0 {
 		end := bytes.IndexByte(in, '\n')
 		if end != -1 {
@@ -102,8 +106,8 @@
 			continue
 		}
 
-		if out, comment, ok = parseAuthorizedKey(in[i:]); ok {
-			return
+		if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
+			return out, comment, options, rest, nil
 		}
 
 		// No key type recognised. Maybe there's an options field at
@@ -143,38 +147,42 @@
 			continue
 		}
 
-		if out, comment, ok = parseAuthorizedKey(in[i:]); ok {
+		if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
 			options = candidateOptions
-			return
+			return out, comment, options, rest, nil
 		}
 
 		in = rest
 		continue
 	}
 
-	return
+	return nil, "", nil, nil, errors.New("ssh: no key found")
 }
 
 // ParsePublicKey parses an SSH public key formatted for use in
 // the SSH wire protocol according to RFC 4253, section 6.6.
-func ParsePublicKey(in []byte) (out PublicKey, rest []byte, ok bool) {
+func ParsePublicKey(in []byte) (out PublicKey, err error) {
 	algo, in, ok := parseString(in)
 	if !ok {
-		return
+		return nil, errShortRead
+	}
+	var rest []byte
+	out, rest, err = parsePubKey(in, string(algo))
+	if len(rest) > 0 {
+		return nil, errors.New("ssh: trailing junk in public key")
 	}
 
-	return parsePubKey(in, string(algo))
+	return out, err
 }
 
-// MarshalAuthorizedKey returns a byte stream suitable for inclusion
-// in an OpenSSH authorized_keys file following the format specified
-// in the sshd(8) manual page.
+// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
+// authorized_keys file. The return value ends with newline.
 func MarshalAuthorizedKey(key PublicKey) []byte {
 	b := &bytes.Buffer{}
-	b.WriteString(key.PublicKeyAlgo())
+	b.WriteString(key.Type())
 	b.WriteByte(' ')
 	e := base64.NewEncoder(base64.StdEncoding, b)
-	e.Write(MarshalPublicKey(key))
+	e.Write(key.Marshal())
 	e.Close()
 	b.WriteByte('\n')
 	return b.Bytes()
@@ -182,84 +190,81 @@
 
 // PublicKey is an abstraction of different types of public keys.
 type PublicKey interface {
-	// PrivateKeyAlgo returns the name of the encryption system.
-	PrivateKeyAlgo() string
-
-	// PublicKeyAlgo returns the algorithm for the public key,
-	// which may be different from PrivateKeyAlgo for certificates.
-	PublicKeyAlgo() string
+	// Type returns the key's type, e.g. "ssh-rsa".
+	Type() string
 
 	// Marshal returns the serialized key data in SSH wire format,
-	// without the name prefix.  Callers should typically use
-	// MarshalPublicKey().
+	// with the name prefix.
 	Marshal() []byte
 
 	// Verify that sig is a signature on the given data using this
 	// key. This function will hash the data appropriately first.
-	Verify(data []byte, sigBlob []byte) bool
+	Verify(data []byte, sig *Signature) error
 }
 
-// A Signer is can create signatures that verify against a public key.
+// A Signer can create signatures that verify against a public key.
 type Signer interface {
 	// PublicKey returns an associated PublicKey instance.
 	PublicKey() PublicKey
 
 	// Sign returns raw signature for the given data. This method
 	// will apply the hash specified for the keytype to the data.
-	Sign(rand io.Reader, data []byte) ([]byte, error)
+	Sign(rand io.Reader, data []byte) (*Signature, error)
 }
 
 type rsaPublicKey rsa.PublicKey
 
-func (r *rsaPublicKey) PrivateKeyAlgo() string {
+func (r *rsaPublicKey) Type() string {
 	return "ssh-rsa"
 }
 
-func (r *rsaPublicKey) PublicKeyAlgo() string {
-	return r.PrivateKeyAlgo()
-}
-
 // parseRSA parses an RSA key according to RFC 4253, section 6.6.
-func parseRSA(in []byte) (out PublicKey, rest []byte, ok bool) {
-	key := new(rsa.PublicKey)
-
-	bigE, in, ok := parseInt(in)
-	if !ok || bigE.BitLen() > 24 {
-		return
+func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
+	var w struct {
+		E    *big.Int
+		N    *big.Int
+		Rest []byte `ssh:"rest"`
 	}
-	e := bigE.Int64()
+	if err := Unmarshal(in, &w); err != nil {
+		return nil, nil, err
+	}
+
+	if w.E.BitLen() > 24 {
+		return nil, nil, errors.New("ssh: exponent too large")
+	}
+	e := w.E.Int64()
 	if e < 3 || e&1 == 0 {
-		ok = false
-		return
+		return nil, nil, errors.New("ssh: incorrect exponent")
 	}
+
+	var key rsa.PublicKey
 	key.E = int(e)
-
-	if key.N, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	ok = true
-	return (*rsaPublicKey)(key), in, ok
+	key.N = w.N
+	return (*rsaPublicKey)(&key), w.Rest, nil
 }
 
 func (r *rsaPublicKey) Marshal() []byte {
-	// See RFC 4253, section 6.6.
 	e := new(big.Int).SetInt64(int64(r.E))
-	length := intLength(e)
-	length += intLength(r.N)
-
-	ret := make([]byte, length)
-	rest := marshalInt(ret, e)
-	marshalInt(rest, r.N)
-
-	return ret
+	wirekey := struct {
+		Name string
+		E    *big.Int
+		N    *big.Int
+	}{
+		KeyAlgoRSA,
+		e,
+		r.N,
+	}
+	return Marshal(&wirekey)
 }
 
-func (r *rsaPublicKey) Verify(data []byte, sig []byte) bool {
+func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
+	if sig.Format != r.Type() {
+		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
+	}
 	h := crypto.SHA1.New()
 	h.Write(data)
 	digest := h.Sum(nil)
-	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig) == nil
+	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob)
 }
 
 type rsaPrivateKey struct {
@@ -270,64 +275,66 @@
 	return (*rsaPublicKey)(&r.PrivateKey.PublicKey)
 }
 
-func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
 	h := crypto.SHA1.New()
 	h.Write(data)
 	digest := h.Sum(nil)
-	return rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
+	blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
+	if err != nil {
+		return nil, err
+	}
+	return &Signature{
+		Format: r.PublicKey().Type(),
+		Blob:   blob,
+	}, nil
 }
 
 type dsaPublicKey dsa.PublicKey
 
-func (r *dsaPublicKey) PrivateKeyAlgo() string {
+func (r *dsaPublicKey) Type() string {
 	return "ssh-dss"
 }
 
-func (r *dsaPublicKey) PublicKeyAlgo() string {
-	return r.PrivateKeyAlgo()
-}
-
 // parseDSA parses an DSA key according to RFC 4253, section 6.6.
-func parseDSA(in []byte) (out PublicKey, rest []byte, ok bool) {
-	key := new(dsa.PublicKey)
-
-	if key.P, in, ok = parseInt(in); !ok {
-		return
+func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
+	var w struct {
+		P, Q, G, Y *big.Int
+		Rest       []byte `ssh:"rest"`
+	}
+	if err := Unmarshal(in, &w); err != nil {
+		return nil, nil, err
 	}
 
-	if key.Q, in, ok = parseInt(in); !ok {
-		return
+	key := &dsaPublicKey{
+		Parameters: dsa.Parameters{
+			P: w.P,
+			Q: w.Q,
+			G: w.G,
+		},
+		Y: w.Y,
 	}
-
-	if key.G, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	if key.Y, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	ok = true
-	return (*dsaPublicKey)(key), in, ok
+	return key, w.Rest, nil
 }
 
-func (r *dsaPublicKey) Marshal() []byte {
-	// See RFC 4253, section 6.6.
-	length := intLength(r.P)
-	length += intLength(r.Q)
-	length += intLength(r.G)
-	length += intLength(r.Y)
+func (k *dsaPublicKey) Marshal() []byte {
+	w := struct {
+		Name       string
+		P, Q, G, Y *big.Int
+	}{
+		k.Type(),
+		k.P,
+		k.Q,
+		k.G,
+		k.Y,
+	}
 
-	ret := make([]byte, length)
-	rest := marshalInt(ret, r.P)
-	rest = marshalInt(rest, r.Q)
-	rest = marshalInt(rest, r.G)
-	marshalInt(rest, r.Y)
-
-	return ret
+	return Marshal(&w)
 }
 
-func (k *dsaPublicKey) Verify(data []byte, sigBlob []byte) bool {
+func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
+	if sig.Format != k.Type() {
+		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
+	}
 	h := crypto.SHA1.New()
 	h.Write(data)
 	digest := h.Sum(nil)
@@ -337,12 +344,15 @@
 	// r, followed by s (which are 160-bit integers, without lengths or
 	// padding, unsigned, and in network byte order).
 	// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
-	if len(sigBlob) != 40 {
-		return false
+	if len(sig.Blob) != 40 {
+		return errors.New("ssh: DSA signature parse error")
 	}
-	r := new(big.Int).SetBytes(sigBlob[:20])
-	s := new(big.Int).SetBytes(sigBlob[20:])
-	return dsa.Verify((*dsa.PublicKey)(k), digest, r, s)
+	r := new(big.Int).SetBytes(sig.Blob[:20])
+	s := new(big.Int).SetBytes(sig.Blob[20:])
+	if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) {
+		return nil
+	}
+	return errors.New("ssh: signature did not verify")
 }
 
 type dsaPrivateKey struct {
@@ -353,7 +363,7 @@
 	return (*dsaPublicKey)(&k.PrivateKey.PublicKey)
 }
 
-func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
 	h := crypto.SHA1.New()
 	h.Write(data)
 	digest := h.Sum(nil)
@@ -363,14 +373,21 @@
 	}
 
 	sig := make([]byte, 40)
-	copy(sig[:20], r.Bytes())
-	copy(sig[20:], s.Bytes())
-	return sig, nil
+	rb := r.Bytes()
+	sb := s.Bytes()
+
+	copy(sig[20-len(rb):20], rb)
+	copy(sig[40-len(sb):], sb)
+
+	return &Signature{
+		Format: k.PublicKey().Type(),
+		Blob:   sig,
+	}, nil
 }
 
 type ecdsaPublicKey ecdsa.PublicKey
 
-func (key *ecdsaPublicKey) PrivateKeyAlgo() string {
+func (key *ecdsaPublicKey) Type() string {
 	return "ecdsa-sha2-" + key.nistID()
 }
 
@@ -387,7 +404,7 @@
 }
 
 func supportedEllipticCurve(curve elliptic.Curve) bool {
-	return (curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521())
+	return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521()
 }
 
 // ecHash returns the hash to match the given elliptic curve, see RFC
@@ -403,15 +420,11 @@
 	return crypto.SHA512
 }
 
-func (key *ecdsaPublicKey) PublicKeyAlgo() string {
-	return key.PrivateKeyAlgo()
-}
-
 // parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
-func parseECDSA(in []byte) (out PublicKey, rest []byte, ok bool) {
-	var identifier []byte
-	if identifier, in, ok = parseString(in); !ok {
-		return
+func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
+	identifier, in, ok := parseString(in)
+	if !ok {
+		return nil, nil, errShortRead
 	}
 
 	key := new(ecdsa.PublicKey)
@@ -424,38 +437,42 @@
 	case "nistp521":
 		key.Curve = elliptic.P521()
 	default:
-		ok = false
-		return
+		return nil, nil, errors.New("ssh: unsupported curve")
 	}
 
 	var keyBytes []byte
 	if keyBytes, in, ok = parseString(in); !ok {
-		return
+		return nil, nil, errShortRead
 	}
 
 	key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes)
 	if key.X == nil || key.Y == nil {
-		ok = false
-		return
+		return nil, nil, errors.New("ssh: invalid curve point")
 	}
-	return (*ecdsaPublicKey)(key), in, ok
+	return (*ecdsaPublicKey)(key), in, nil
 }
 
 func (key *ecdsaPublicKey) Marshal() []byte {
 	// See RFC 5656, section 3.1.
 	keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
+	w := struct {
+		Name string
+		ID   string
+		Key  []byte
+	}{
+		key.Type(),
+		key.nistID(),
+		keyBytes,
+	}
 
-	ID := key.nistID()
-	length := stringLength(len(ID))
-	length += stringLength(len(keyBytes))
-
-	ret := make([]byte, length)
-	r := marshalString(ret, []byte(ID))
-	r = marshalString(r, keyBytes)
-	return ret
+	return Marshal(&w)
 }
 
-func (key *ecdsaPublicKey) Verify(data []byte, sigBlob []byte) bool {
+func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
+	if sig.Format != key.Type() {
+		return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
+	}
+
 	h := ecHash(key.Curve).New()
 	h.Write(data)
 	digest := h.Sum(nil)
@@ -464,15 +481,19 @@
 	// The ecdsa_signature_blob value has the following specific encoding:
 	//    mpint    r
 	//    mpint    s
-	r, rest, ok := parseInt(sigBlob)
-	if !ok {
-		return false
+	var ecSig struct {
+		R *big.Int
+		S *big.Int
 	}
-	s, rest, ok := parseInt(rest)
-	if !ok || len(rest) > 0 {
-		return false
+
+	if err := Unmarshal(sig.Blob, &ecSig); err != nil {
+		return err
 	}
-	return ecdsa.Verify((*ecdsa.PublicKey)(key), digest, r, s)
+
+	if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) {
+		return nil
+	}
+	return errors.New("ssh: signature did not verify")
 }
 
 type ecdsaPrivateKey struct {
@@ -483,7 +504,7 @@
 	return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey)
 }
 
-func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
 	h := ecHash(k.PrivateKey.PublicKey.Curve).New()
 	h.Write(data)
 	digest := h.Sum(nil)
@@ -495,10 +516,13 @@
 	sig := make([]byte, intLength(r)+intLength(s))
 	rest := marshalInt(sig, r)
 	marshalInt(rest, s)
-	return sig, nil
+	return &Signature{
+		Format: k.PublicKey().Type(),
+		Blob:   sig,
+	}, nil
 }
 
-// NewPrivateKey takes a pointer to rsa, dsa or ecdsa PrivateKey
+// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
 // returns a corresponding Signer instance. EC keys should use P256,
 // P384 or P521.
 func NewSignerFromKey(k interface{}) (Signer, error) {
@@ -540,54 +564,49 @@
 	return sshKey, nil
 }
 
-// ParsePublicKey parses a PEM encoded private key. It supports
-// PKCS#1, RSA, DSA and ECDSA private keys.
+// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
+// the same keys as ParseRawPrivateKey.
 func ParsePrivateKey(pemBytes []byte) (Signer, error) {
+	key, err := ParseRawPrivateKey(pemBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	return NewSignerFromKey(key)
+}
+
+// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
+// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
+func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
 	block, _ := pem.Decode(pemBytes)
 	if block == nil {
 		return nil, errors.New("ssh: no key found")
 	}
 
-	var rawkey interface{}
 	switch block.Type {
 	case "RSA PRIVATE KEY":
-		rsa, err := x509.ParsePKCS1PrivateKey(block.Bytes)
-		if err != nil {
-			return nil, err
-		}
-		rawkey = rsa
+		return x509.ParsePKCS1PrivateKey(block.Bytes)
 	case "EC PRIVATE KEY":
-		ec, err := x509.ParseECPrivateKey(block.Bytes)
-		if err != nil {
-			return nil, err
-		}
-		rawkey = ec
+		return x509.ParseECPrivateKey(block.Bytes)
 	case "DSA PRIVATE KEY":
-		ec, err := parseDSAPrivate(block.Bytes)
-		if err != nil {
-			return nil, err
-		}
-		rawkey = ec
+		return ParseDSAPrivateKey(block.Bytes)
 	default:
 		return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
 	}
-
-	return NewSignerFromKey(rawkey)
 }
 
-// parseDSAPrivate parses a DSA key in ASN.1 DER encoding, as
-// documented in the OpenSSL DSA manpage.
-// TODO(hanwen): move this in to crypto/x509 after the Go 1.2 freeze.
-func parseDSAPrivate(p []byte) (*dsa.PrivateKey, error) {
-	k := struct {
+// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
+// specified by the OpenSSL DSA man page.
+func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
+	var k struct {
 		Version int
 		P       *big.Int
 		Q       *big.Int
 		G       *big.Int
 		Priv    *big.Int
 		Pub     *big.Int
-	}{}
-	rest, err := asn1.Unmarshal(p, &k)
+	}
+	rest, err := asn1.Unmarshal(der, &k)
 	if err != nil {
 		return nil, errors.New("ssh: failed to parse DSA key: " + err.Error())
 	}
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index 3c4b735..cd49565 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -1,66 +1,25 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
 package ssh
 
 import (
+	"bytes"
 	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/rsa"
+	"encoding/base64"
+	"fmt"
 	"reflect"
 	"strings"
 	"testing"
+
+	"code.google.com/p/go.crypto/ssh/testdata"
 )
 
-var (
-	ecdsaKey    Signer
-	ecdsa384Key Signer
-	ecdsa521Key Signer
-	testCertKey Signer
-)
-
-type testSigner struct {
-	Signer
-	pub PublicKey
-}
-
-func (ts *testSigner) PublicKey() PublicKey {
-	if ts.pub != nil {
-		return ts.pub
-	}
-	return ts.Signer.PublicKey()
-}
-
-func init() {
-	raw256, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	ecdsaKey, _ = NewSignerFromKey(raw256)
-
-	raw384, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
-	ecdsa384Key, _ = NewSignerFromKey(raw384)
-
-	raw521, _ := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
-	ecdsa521Key, _ = NewSignerFromKey(raw521)
-
-	// Create a cert and sign it for use in tests.
-	testCert := &OpenSSHCertV01{
-		Nonce:           []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
-		Key:             ecdsaKey.PublicKey(),
-		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
-		ValidAfter:      0,                              // unix epoch
-		ValidBefore:     maxUint64,                      // The end of currently representable time.
-		Reserved:        []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
-		SignatureKey:    rsaKey.PublicKey(),
-	}
-	sigBytes, _ := rsaKey.Sign(rand.Reader, testCert.BytesForSigning())
-	testCert.Signature = &signature{
-		Format: testCert.SignatureKey.PublicKeyAlgo(),
-		Blob:   sigBytes,
-	}
-	testCertKey = &testSigner{
-		Signer: ecdsaKey,
-		pub:    testCert,
-	}
-}
-
 func rawKey(pub PublicKey) interface{} {
 	switch k := pub.(type) {
 	case *rsaPublicKey:
@@ -69,23 +28,18 @@
 		return (*dsa.PublicKey)(k)
 	case *ecdsaPublicKey:
 		return (*ecdsa.PublicKey)(k)
-	case *OpenSSHCertV01:
+	case *Certificate:
 		return k
 	}
 	panic("unknown key type")
 }
 
 func TestKeyMarshalParse(t *testing.T) {
-	keys := []Signer{rsaKey, dsaKey, ecdsaKey, ecdsa384Key, ecdsa521Key, testCertKey}
-	for _, priv := range keys {
+	for _, priv := range testSigners {
 		pub := priv.PublicKey()
-		roundtrip, rest, ok := ParsePublicKey(MarshalPublicKey(pub))
-		if !ok {
-			t.Errorf("ParsePublicKey(%T) failed", pub)
-		}
-
-		if len(rest) > 0 {
-			t.Errorf("ParsePublicKey(%T): trailing junk", pub)
+		roundtrip, err := ParsePublicKey(pub.Marshal())
+		if err != nil {
+			t.Errorf("ParsePublicKey(%T): %v", pub, err)
 		}
 
 		k1 := rawKey(pub)
@@ -113,9 +67,12 @@
 }
 
 func TestNewPublicKey(t *testing.T) {
-	keys := []Signer{rsaKey, dsaKey, ecdsaKey}
-	for _, k := range keys {
+	for _, k := range testSigners {
 		raw := rawKey(k.PublicKey())
+		// Skip certificates, as NewPublicKey does not support them.
+		if _, ok := raw.(*Certificate); ok {
+			continue
+		}
 		pub, err := NewPublicKey(raw)
 		if err != nil {
 			t.Errorf("NewPublicKey(%#v): %v", raw, err)
@@ -127,8 +84,7 @@
 }
 
 func TestKeySignVerify(t *testing.T) {
-	keys := []Signer{rsaKey, dsaKey, ecdsaKey, testCertKey}
-	for _, priv := range keys {
+	for _, priv := range testSigners {
 		pub := priv.PublicKey()
 
 		data := []byte("sign me")
@@ -137,19 +93,20 @@
 			t.Fatalf("Sign(%T): %v", priv, err)
 		}
 
-		if !pub.Verify(data, sig) {
-			t.Errorf("publicKey.Verify(%T) failed", priv)
+		if err := pub.Verify(data, sig); err != nil {
+			t.Errorf("publicKey.Verify(%T): %v", priv, err)
+		}
+		sig.Blob[5]++
+		if err := pub.Verify(data, sig); err == nil {
+			t.Errorf("publicKey.Verify on broken sig did not fail")
 		}
 	}
 }
 
 func TestParseRSAPrivateKey(t *testing.T) {
-	key, err := ParsePrivateKey([]byte(testServerPrivateKey))
-	if err != nil {
-		t.Fatalf("ParsePrivateKey: %v", err)
-	}
+	key := testPrivateKeys["rsa"]
 
-	rsa, ok := key.(*rsaPrivateKey)
+	rsa, ok := key.(*rsa.PrivateKey)
 	if !ok {
 		t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
 	}
@@ -160,21 +117,11 @@
 }
 
 func TestParseECPrivateKey(t *testing.T) {
-	// Taken from the data in test/ .
-	pem := []byte(`-----BEGIN EC PRIVATE KEY-----
-MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
-AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
-6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
------END EC PRIVATE KEY-----`)
+	key := testPrivateKeys["ecdsa"]
 
-	key, err := ParsePrivateKey(pem)
-	if err != nil {
-		t.Fatalf("ParsePrivateKey: %v", err)
-	}
-
-	ecKey, ok := key.(*ecdsaPrivateKey)
+	ecKey, ok := key.(*ecdsa.PrivateKey)
 	if !ok {
-		t.Fatalf("got %T, want *ecdsaPrivateKey", ecKey)
+		t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey)
 	}
 
 	if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
@@ -182,22 +129,11 @@
 	}
 }
 
-// ssh-keygen -t dsa -f /tmp/idsa.pem
-var dsaPEM = `-----BEGIN DSA PRIVATE KEY-----
-MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB
-lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
-EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD
-nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV
-2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r
-juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr
-FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz
-DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj
-nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY
-Fmsr0W6fHB9nhS4/UXM8
------END DSA PRIVATE KEY-----`
-
 func TestParseDSA(t *testing.T) {
-	s, err := ParsePrivateKey([]byte(dsaPEM))
+	// We actually exercise the ParsePrivateKey codepath here, as opposed to
+	// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
+	// uses.
+	s, err := ParsePrivateKey(testdata.PEMBytes["dsa"])
 	if err != nil {
 		t.Fatalf("ParsePrivateKey returned error: %s", err)
 	}
@@ -208,7 +144,163 @@
 		t.Fatalf("dsa.Sign: %v", err)
 	}
 
-	if !s.PublicKey().Verify(data, sig) {
-		t.Error("Verify failed.")
+	if err := s.PublicKey().Verify(data, sig); err != nil {
+		t.Errorf("Verify failed: %v", err)
+	}
+}
+
+// Tests for authorized_keys parsing.
+
+// getTestKey returns a public key, and its base64 encoding.
+func getTestKey() (PublicKey, string) {
+	k := testPublicKeys["rsa"]
+
+	b := &bytes.Buffer{}
+	e := base64.NewEncoder(base64.StdEncoding, b)
+	e.Write(k.Marshal())
+	e.Close()
+
+	return k, b.String()
+}
+
+func TestMarshalParsePublicKey(t *testing.T) {
+	pub, pubSerialized := getTestKey()
+	line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized)
+
+	authKeys := MarshalAuthorizedKey(pub)
+	actualFields := strings.Fields(string(authKeys))
+	if len(actualFields) == 0 {
+		t.Fatalf("failed authKeys: %v", authKeys)
+	}
+
+	// drop the comment
+	expectedFields := strings.Fields(line)[0:2]
+
+	if !reflect.DeepEqual(actualFields, expectedFields) {
+		t.Errorf("got %v, expected %v", actualFields, expectedFields)
+	}
+
+	actPub, _, _, _, err := ParseAuthorizedKey([]byte(line))
+	if err != nil {
+		t.Fatalf("cannot parse %v: %v", line, err)
+	}
+	if !reflect.DeepEqual(actPub, pub) {
+		t.Errorf("got %v, expected %v", actPub, pub)
+	}
+}
+
+type authResult struct {
+	pubKey   PublicKey
+	options  []string
+	comments string
+	rest     string
+	ok       bool
+}
+
+func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
+	rest := authKeys
+	var values []authResult
+	for len(rest) > 0 {
+		var r authResult
+		var err error
+		r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
+		r.ok = (err == nil)
+		t.Log(err)
+		r.rest = string(rest)
+		values = append(values, r)
+	}
+
+	if !reflect.DeepEqual(values, expected) {
+		t.Errorf("got %#v, expected %#v", values, expected)
+	}
+}
+
+func TestAuthorizedKeyBasic(t *testing.T) {
+	pub, pubSerialized := getTestKey()
+	line := "ssh-rsa " + pubSerialized + " user@host"
+	testAuthorizedKeys(t, []byte(line),
+		[]authResult{
+			{pub, nil, "user@host", "", true},
+		})
+}
+
+func TestAuth(t *testing.T) {
+	pub, pubSerialized := getTestKey()
+	authWithOptions := []string{
+		`# comments to ignore before any keys...`,
+		``,
+		`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`,
+		`# comments to ignore, along with a blank line`,
+		``,
+		`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`,
+		``,
+		`# more comments, plus a invalid entry`,
+		`ssh-rsa data-that-will-not-parse user@host3`,
+	}
+	for _, eol := range []string{"\n", "\r\n"} {
+		authOptions := strings.Join(authWithOptions, eol)
+		rest2 := strings.Join(authWithOptions[3:], eol)
+		rest3 := strings.Join(authWithOptions[6:], eol)
+		testAuthorizedKeys(t, []byte(authOptions), []authResult{
+			{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
+			{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
+			{nil, nil, "", "", false},
+		})
+	}
+}
+
+func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
+	pub, pubSerialized := getTestKey()
+	authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
+	testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
+		{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
+	})
+}
+
+func TestAuthWithQuotedCommaInEnv(t *testing.T) {
+	pub, pubSerialized := getTestKey()
+	authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + `   user@host`)
+	testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
+		{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
+	})
+}
+
+func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
+	pub, pubSerialized := getTestKey()
+	authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + `   user@host`)
+	authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
+	testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
+		{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
+	})
+
+	testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
+		{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
+	})
+}
+
+func TestAuthWithInvalidSpace(t *testing.T) {
+	_, pubSerialized := getTestKey()
+	authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
+#more to follow but still no valid keys`)
+	testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
+		{nil, nil, "", "", false},
+	})
+}
+
+func TestAuthWithMissingQuote(t *testing.T) {
+	pub, pubSerialized := getTestKey()
+	authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
+env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
+
+	testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
+		{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
+	})
+}
+
+func TestInvalidEntry(t *testing.T) {
+	authInvalid := []byte(`ssh-rsa`)
+	_, _, _, _, err := ParseAuthorizedKey(authInvalid)
+	if err == nil {
+		t.Errorf("got valid entry for %q", authInvalid)
 	}
 }
diff --git a/ssh/mac.go b/ssh/mac.go
index 6862d3e..aff4042 100644
--- a/ssh/mac.go
+++ b/ssh/mac.go
@@ -43,11 +43,6 @@
 
 func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
 
-// Specifies a default set of MAC algorithms and a preference order.
-// This is based on RFC 4253, section 6.4, with the removal of the
-// hmac-md5 variants as they have reached the end of their useful life.
-var DefaultMACOrder = []string{"hmac-sha1", "hmac-sha1-96"}
-
 var macModes = map[string]*macMode{
 	"hmac-sha1": {20, func(key []byte) hash.Hash {
 		return hmac.New(sha1.New, key)
diff --git a/ssh/mempipe_test.go b/ssh/mempipe_test.go
index ec1b854..69217a4 100644
--- a/ssh/mempipe_test.go
+++ b/ssh/mempipe_test.go
@@ -36,17 +36,23 @@
 	}
 }
 
-func (t *memTransport) Close() error {
-	t.write.Lock()
-	defer t.write.Unlock()
-	if t.write.eof {
+func (t *memTransport) closeSelf() error {
+	t.Lock()
+	defer t.Unlock()
+	if t.eof {
 		return io.EOF
 	}
-	t.write.eof = true
-	t.write.Cond.Broadcast()
+	t.eof = true
+	t.Cond.Broadcast()
 	return nil
 }
 
+func (t *memTransport) Close() error {
+	err := t.write.closeSelf()
+	t.closeSelf()
+	return err
+}
+
 func (t *memTransport) writePacket(p []byte) error {
 	t.write.Lock()
 	defer t.write.Unlock()
diff --git a/ssh/messages.go b/ssh/messages.go
index 94c3ea0..f9e44bb 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -7,58 +7,25 @@
 import (
 	"bytes"
 	"encoding/binary"
+	"errors"
+	"fmt"
 	"io"
 	"math/big"
 	"reflect"
+	"strconv"
 )
 
 // These are SSH message type numbers. They are scattered around several
 // documents but many were taken from [SSH-PARAMETERS].
 const (
-	msgDisconnect     = 1
-	msgIgnore         = 2
-	msgUnimplemented  = 3
-	msgDebug          = 4
-	msgServiceRequest = 5
-	msgServiceAccept  = 6
-
-	msgKexInit = 20
-	msgNewKeys = 21
-
-	// Diffie-Helman
-	msgKexDHInit  = 30
-	msgKexDHReply = 31
-
-	msgKexECDHInit  = 30
-	msgKexECDHReply = 31
+	msgIgnore        = 2
+	msgUnimplemented = 3
+	msgDebug         = 4
+	msgNewKeys       = 21
 
 	// Standard authentication messages
-	msgUserAuthRequest  = 50
-	msgUserAuthFailure  = 51
-	msgUserAuthSuccess  = 52
-	msgUserAuthBanner   = 53
-	msgUserAuthPubKeyOk = 60
-
-	// Method specific messages
-	msgUserAuthInfoRequest  = 60
-	msgUserAuthInfoResponse = 61
-
-	msgGlobalRequest  = 80
-	msgRequestSuccess = 81
-	msgRequestFailure = 82
-
-	// Channel manipulation
-	msgChannelOpen         = 90
-	msgChannelOpenConfirm  = 91
-	msgChannelOpenFailure  = 92
-	msgChannelWindowAdjust = 93
-	msgChannelData         = 94
-	msgChannelExtendedData = 95
-	msgChannelEOF          = 96
-	msgChannelClose        = 97
-	msgChannelRequest      = 98
-	msgChannelSuccess      = 99
-	msgChannelFailure      = 100
+	msgUserAuthSuccess = 52
+	msgUserAuthBanner  = 53
 )
 
 // SSH messages:
@@ -69,15 +36,25 @@
 // ssh tag of "rest" receives the remainder of a packet when unmarshaling.
 
 // See RFC 4253, section 11.1.
+const msgDisconnect = 1
+
+// disconnectMsg is the message that signals a disconnect. It is also
+// the error type returned from mux.Wait()
 type disconnectMsg struct {
-	Reason   uint32
+	Reason   uint32 `sshtype:"1"`
 	Message  string
 	Language string
 }
 
+func (d *disconnectMsg) Error() string {
+	return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
+}
+
 // See RFC 4253, section 7.1.
+const msgKexInit = 20
+
 type kexInitMsg struct {
-	Cookie                  [16]byte
+	Cookie                  [16]byte `sshtype:"20"`
 	KexAlgos                []string
 	ServerHostKeyAlgos      []string
 	CiphersClientServer     []string
@@ -93,53 +70,74 @@
 }
 
 // See RFC 4253, section 8.
+
+// Diffie-Helman
+const msgKexDHInit = 30
+
 type kexDHInitMsg struct {
-	X *big.Int
+	X *big.Int `sshtype:"30"`
 }
 
+const msgKexECDHInit = 30
+
 type kexECDHInitMsg struct {
-	ClientPubKey []byte
+	ClientPubKey []byte `sshtype:"30"`
 }
 
+const msgKexECDHReply = 31
+
 type kexECDHReplyMsg struct {
-	HostKey         []byte
+	HostKey         []byte `sshtype:"31"`
 	EphemeralPubKey []byte
 	Signature       []byte
 }
 
+const msgKexDHReply = 31
+
 type kexDHReplyMsg struct {
-	HostKey   []byte
+	HostKey   []byte `sshtype:"31"`
 	Y         *big.Int
 	Signature []byte
 }
 
 // See RFC 4253, section 10.
+const msgServiceRequest = 5
+
 type serviceRequestMsg struct {
-	Service string
+	Service string `sshtype:"5"`
 }
 
 // See RFC 4253, section 10.
+const msgServiceAccept = 6
+
 type serviceAcceptMsg struct {
-	Service string
+	Service string `sshtype:"6"`
 }
 
 // See RFC 4252, section 5.
+const msgUserAuthRequest = 50
+
 type userAuthRequestMsg struct {
-	User    string
+	User    string `sshtype:"50"`
 	Service string
 	Method  string
 	Payload []byte `ssh:"rest"`
 }
 
 // See RFC 4252, section 5.1
+const msgUserAuthFailure = 51
+
 type userAuthFailureMsg struct {
-	Methods        []string
+	Methods        []string `sshtype:"51"`
 	PartialSuccess bool
 }
 
 // See RFC 4256, section 3.2
+const msgUserAuthInfoRequest = 60
+const msgUserAuthInfoResponse = 61
+
 type userAuthInfoRequestMsg struct {
-	User               string
+	User               string `sshtype:"60"`
 	Instruction        string
 	DeprecatedLanguage string
 	NumPrompts         uint32
@@ -147,17 +145,24 @@
 }
 
 // See RFC 4254, section 5.1.
+const msgChannelOpen = 90
+
 type channelOpenMsg struct {
-	ChanType         string
+	ChanType         string `sshtype:"90"`
 	PeersId          uint32
 	PeersWindow      uint32
 	MaxPacketSize    uint32
 	TypeSpecificData []byte `ssh:"rest"`
 }
 
+const msgChannelExtendedData = 95
+const msgChannelData = 94
+
 // See RFC 4254, section 5.1.
+const msgChannelOpenConfirm = 91
+
 type channelOpenConfirmMsg struct {
-	PeersId          uint32
+	PeersId          uint32 `sshtype:"91"`
 	MyId             uint32
 	MyWindow         uint32
 	MaxPacketSize    uint32
@@ -165,172 +170,239 @@
 }
 
 // See RFC 4254, section 5.1.
+const msgChannelOpenFailure = 92
+
 type channelOpenFailureMsg struct {
-	PeersId  uint32
+	PeersId  uint32 `sshtype:"92"`
 	Reason   RejectionReason
 	Message  string
 	Language string
 }
 
+const msgChannelRequest = 98
+
 type channelRequestMsg struct {
-	PeersId             uint32
+	PeersId             uint32 `sshtype:"98"`
 	Request             string
 	WantReply           bool
 	RequestSpecificData []byte `ssh:"rest"`
 }
 
 // See RFC 4254, section 5.4.
+const msgChannelSuccess = 99
+
 type channelRequestSuccessMsg struct {
-	PeersId uint32
+	PeersId uint32 `sshtype:"99"`
 }
 
 // See RFC 4254, section 5.4.
+const msgChannelFailure = 100
+
 type channelRequestFailureMsg struct {
-	PeersId uint32
+	PeersId uint32 `sshtype:"100"`
 }
 
 // See RFC 4254, section 5.3
+const msgChannelClose = 97
+
 type channelCloseMsg struct {
-	PeersId uint32
+	PeersId uint32 `sshtype:"97"`
 }
 
 // See RFC 4254, section 5.3
+const msgChannelEOF = 96
+
 type channelEOFMsg struct {
-	PeersId uint32
+	PeersId uint32 `sshtype:"96"`
 }
 
 // See RFC 4254, section 4
+const msgGlobalRequest = 80
+
 type globalRequestMsg struct {
-	Type      string
+	Type      string `sshtype:"80"`
 	WantReply bool
+	Data      []byte `ssh:"rest"`
 }
 
 // See RFC 4254, section 4
+const msgRequestSuccess = 81
+
 type globalRequestSuccessMsg struct {
-	Data []byte `ssh:"rest"`
+	Data []byte `ssh:"rest" sshtype:"81"`
 }
 
 // See RFC 4254, section 4
+const msgRequestFailure = 82
+
 type globalRequestFailureMsg struct {
-	Data []byte `ssh:"rest"`
+	Data []byte `ssh:"rest" sshtype:"82"`
 }
 
 // See RFC 4254, section 5.2
+const msgChannelWindowAdjust = 93
+
 type windowAdjustMsg struct {
-	PeersId         uint32
+	PeersId         uint32 `sshtype:"93"`
 	AdditionalBytes uint32
 }
 
 // See RFC 4252, section 7
+const msgUserAuthPubKeyOk = 60
+
 type userAuthPubKeyOkMsg struct {
-	Algo   string
-	PubKey string
+	Algo   string `sshtype:"60"`
+	PubKey []byte
 }
 
-// unmarshal parses the SSH wire data in packet into out using
-// reflection. expectedType, if non-zero, is the SSH message type that
-// the packet is expected to start with.  unmarshal either returns nil
-// on success, or a ParseError or UnexpectedMessageError on error.
-func unmarshal(out interface{}, packet []byte, expectedType uint8) error {
-	if len(packet) == 0 {
-		return ParseError{expectedType}
+// typeTag returns the type byte for the given type. The type should
+// be struct.
+func typeTag(structType reflect.Type) byte {
+	var tag byte
+	var tagStr string
+	tagStr = structType.Field(0).Tag.Get("sshtype")
+	i, err := strconv.Atoi(tagStr)
+	if err == nil {
+		tag = byte(i)
 	}
-	if expectedType > 0 {
-		if packet[0] != expectedType {
-			return UnexpectedMessageError{expectedType, packet[0]}
-		}
-		packet = packet[1:]
-	}
+	return tag
+}
 
+func fieldError(t reflect.Type, field int, problem string) error {
+	if problem != "" {
+		problem = ": " + problem
+	}
+	return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem)
+}
+
+var errShortRead = errors.New("ssh: short read")
+
+// Unmarshal parses data in SSH wire format into a structure. The out
+// argument should be a pointer to struct. If the first member of the
+// struct has the "sshtype" tag set to a number in decimal, the packet
+// must start that number.  In case of error, Unmarshal returns a
+// ParseError or UnexpectedMessageError.
+func Unmarshal(data []byte, out interface{}) error {
 	v := reflect.ValueOf(out).Elem()
 	structType := v.Type()
+	expectedType := typeTag(structType)
+	if len(data) == 0 {
+		return parseError(expectedType)
+	}
+	if expectedType > 0 {
+		if data[0] != expectedType {
+			return unexpectedMessageError(expectedType, data[0])
+		}
+		data = data[1:]
+	}
+
 	var ok bool
 	for i := 0; i < v.NumField(); i++ {
 		field := v.Field(i)
 		t := field.Type()
 		switch t.Kind() {
 		case reflect.Bool:
-			if len(packet) < 1 {
-				return ParseError{expectedType}
+			if len(data) < 1 {
+				return errShortRead
 			}
-			field.SetBool(packet[0] != 0)
-			packet = packet[1:]
+			field.SetBool(data[0] != 0)
+			data = data[1:]
 		case reflect.Array:
 			if t.Elem().Kind() != reflect.Uint8 {
-				panic("array of non-uint8")
+				return fieldError(structType, i, "array of unsupported type")
 			}
-			if len(packet) < t.Len() {
-				return ParseError{expectedType}
+			if len(data) < t.Len() {
+				return errShortRead
 			}
 			for j, n := 0, t.Len(); j < n; j++ {
-				field.Index(j).Set(reflect.ValueOf(packet[j]))
+				field.Index(j).Set(reflect.ValueOf(data[j]))
 			}
-			packet = packet[t.Len():]
+			data = data[t.Len():]
+		case reflect.Uint64:
+			var u64 uint64
+			if u64, data, ok = parseUint64(data); !ok {
+				return errShortRead
+			}
+			field.SetUint(u64)
 		case reflect.Uint32:
 			var u32 uint32
-			if u32, packet, ok = parseUint32(packet); !ok {
-				return ParseError{expectedType}
+			if u32, data, ok = parseUint32(data); !ok {
+				return errShortRead
 			}
 			field.SetUint(uint64(u32))
+		case reflect.Uint8:
+			if len(data) < 1 {
+				return errShortRead
+			}
+			field.SetUint(uint64(data[0]))
+			data = data[1:]
 		case reflect.String:
 			var s []byte
-			if s, packet, ok = parseString(packet); !ok {
-				return ParseError{expectedType}
+			if s, data, ok = parseString(data); !ok {
+				return fieldError(structType, i, "")
 			}
 			field.SetString(string(s))
 		case reflect.Slice:
 			switch t.Elem().Kind() {
 			case reflect.Uint8:
 				if structType.Field(i).Tag.Get("ssh") == "rest" {
-					field.Set(reflect.ValueOf(packet))
-					packet = nil
+					field.Set(reflect.ValueOf(data))
+					data = nil
 				} else {
 					var s []byte
-					if s, packet, ok = parseString(packet); !ok {
-						return ParseError{expectedType}
+					if s, data, ok = parseString(data); !ok {
+						return errShortRead
 					}
 					field.Set(reflect.ValueOf(s))
 				}
 			case reflect.String:
 				var nl []string
-				if nl, packet, ok = parseNameList(packet); !ok {
-					return ParseError{expectedType}
+				if nl, data, ok = parseNameList(data); !ok {
+					return errShortRead
 				}
 				field.Set(reflect.ValueOf(nl))
 			default:
-				panic("slice of unknown type")
+				return fieldError(structType, i, "slice of unsupported type")
 			}
 		case reflect.Ptr:
 			if t == bigIntType {
 				var n *big.Int
-				if n, packet, ok = parseInt(packet); !ok {
-					return ParseError{expectedType}
+				if n, data, ok = parseInt(data); !ok {
+					return errShortRead
 				}
 				field.Set(reflect.ValueOf(n))
 			} else {
-				panic("pointer to unknown type")
+				return fieldError(structType, i, "pointer to unsupported type")
 			}
 		default:
-			panic("unknown type")
+			return fieldError(structType, i, "unsupported type")
 		}
 	}
 
-	if len(packet) != 0 {
-		return ParseError{expectedType}
+	if len(data) != 0 {
+		return parseError(expectedType)
 	}
 
 	return nil
 }
 
-// marshal serializes the message in msg. The given message type is
-// prepended if it is non-zero.
-func marshal(msgType uint8, msg interface{}) []byte {
+// Marshal serializes the message in msg to SSH wire format.  The msg
+// argument should be a struct or pointer to struct. If the first
+// member has the "sshtype" tag set to a number in decimal, that
+// number is prepended to the result. If the last of member has the
+// "ssh" tag set to "rest", its contents are appended to the output.
+func Marshal(msg interface{}) []byte {
 	out := make([]byte, 0, 64)
+	return marshalStruct(out, msg)
+}
+
+func marshalStruct(out []byte, msg interface{}) []byte {
+	v := reflect.Indirect(reflect.ValueOf(msg))
+	msgType := typeTag(v.Type())
 	if msgType > 0 {
 		out = append(out, msgType)
 	}
 
-	v := reflect.ValueOf(msg)
 	for i, n := 0, v.NumField(); i < n; i++ {
 		field := v.Field(i)
 		switch t := field.Type(); t.Kind() {
@@ -342,13 +414,17 @@
 			out = append(out, v)
 		case reflect.Array:
 			if t.Elem().Kind() != reflect.Uint8 {
-				panic("array of non-uint8")
+				panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
 			}
 			for j, l := 0, t.Len(); j < l; j++ {
 				out = append(out, uint8(field.Index(j).Uint()))
 			}
 		case reflect.Uint32:
 			out = appendU32(out, uint32(field.Uint()))
+		case reflect.Uint64:
+			out = appendU64(out, uint64(field.Uint()))
+		case reflect.Uint8:
+			out = append(out, uint8(field.Uint()))
 		case reflect.String:
 			s := field.String()
 			out = appendInt(out, len(s))
@@ -375,7 +451,7 @@
 					binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
 				}
 			default:
-				panic("slice of unknown type")
+				panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
 			}
 		case reflect.Ptr:
 			if t == bigIntType {
@@ -393,7 +469,7 @@
 				out = out[:oldLength+needed]
 				marshalInt(out[oldLength:], n)
 			} else {
-				panic("pointer to unknown type")
+				panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
 			}
 		}
 	}
@@ -477,17 +553,6 @@
 	return binary.BigEndian.Uint64(in), in[8:], true
 }
 
-func nameListLength(namelist []string) int {
-	length := 4 /* uint32 length prefix */
-	for i, name := range namelist {
-		if i != 0 {
-			length++ /* comma */
-		}
-		length += len(name)
-	}
-	return length
-}
-
 func intLength(n *big.Int) int {
 	length := 4 /* length bytes */
 	if n.Sign() < 0 {
@@ -650,9 +715,9 @@
 	case msgChannelFailure:
 		msg = new(channelRequestFailureMsg)
 	default:
-		return nil, UnexpectedMessageError{0, packet[0]}
+		return nil, unexpectedMessageError(0, packet[0])
 	}
-	if err := unmarshal(msg, packet, packet[0]); err != nil {
+	if err := Unmarshal(packet, msg); err != nil {
 		return nil, err
 	}
 	return msg, nil
diff --git a/ssh/messages_test.go b/ssh/messages_test.go
index ec1d7be..f14c8a2 100644
--- a/ssh/messages_test.go
+++ b/ssh/messages_test.go
@@ -5,6 +5,7 @@
 package ssh
 
 import (
+	"bytes"
 	"math/big"
 	"math/rand"
 	"reflect"
@@ -32,48 +33,62 @@
 	}
 }
 
-var messageTypes = []interface{}{
-	&kexInitMsg{},
-	&kexDHInitMsg{},
-	&serviceRequestMsg{},
-	&serviceAcceptMsg{},
-	&userAuthRequestMsg{},
-	&channelOpenMsg{},
-	&channelOpenConfirmMsg{},
-	&channelOpenFailureMsg{},
-	&channelRequestMsg{},
-	&channelRequestSuccessMsg{},
+type msgAllTypes struct {
+	Bool    bool `sshtype:"21"`
+	Array   [16]byte
+	Uint64  uint64
+	Uint32  uint32
+	Uint8   uint8
+	String  string
+	Strings []string
+	Bytes   []byte
+	Int     *big.Int
+	Rest    []byte `ssh:"rest"`
+}
+
+func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
+	m := &msgAllTypes{}
+	m.Bool = rand.Intn(2) == 1
+	randomBytes(m.Array[:], rand)
+	m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
+	m.Uint32 = uint32(rand.Intn(1 << 32))
+	m.Uint8 = uint8(rand.Intn(1 << 8))
+	m.String = string(m.Array[:])
+	m.Strings = randomNameList(rand)
+	m.Bytes = m.Array[:]
+	m.Int = randomInt(rand)
+	m.Rest = m.Array[:]
+	return reflect.ValueOf(m)
 }
 
 func TestMarshalUnmarshal(t *testing.T) {
 	rand := rand.New(rand.NewSource(0))
-	for i, iface := range messageTypes {
-		ty := reflect.ValueOf(iface).Type()
+	iface := &msgAllTypes{}
+	ty := reflect.ValueOf(iface).Type()
 
-		n := 100
-		if testing.Short() {
-			n = 5
+	n := 100
+	if testing.Short() {
+		n = 5
+	}
+	for j := 0; j < n; j++ {
+		v, ok := quick.Value(ty, rand)
+		if !ok {
+			t.Errorf("failed to create value")
+			break
 		}
-		for j := 0; j < n; j++ {
-			v, ok := quick.Value(ty, rand)
-			if !ok {
-				t.Errorf("#%d: failed to create value", i)
-				break
-			}
 
-			m1 := v.Elem().Interface()
-			m2 := iface
+		m1 := v.Elem().Interface()
+		m2 := iface
 
-			marshaled := marshal(msgIgnore, m1)
-			if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
-				t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
-				break
-			}
+		marshaled := Marshal(m1)
+		if err := Unmarshal(marshaled, m2); err != nil {
+			t.Errorf("Unmarshal %#v: %s", m1, err)
+			break
+		}
 
-			if !reflect.DeepEqual(v.Interface(), m2) {
-				t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
-				break
-			}
+		if !reflect.DeepEqual(v.Interface(), m2) {
+			t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
+			break
 		}
 	}
 }
@@ -81,33 +96,37 @@
 func TestUnmarshalEmptyPacket(t *testing.T) {
 	var b []byte
 	var m channelRequestSuccessMsg
-	err := unmarshal(&m, b, msgChannelRequest)
-	want := ParseError{msgChannelRequest}
-	if _, ok := err.(ParseError); !ok {
-		t.Fatalf("got %T, want %T", err, want)
-	}
-	if got := err.(ParseError); want != got {
-		t.Fatal("got %#v, want %#v", got, want)
+	if err := Unmarshal(b, &m); err == nil {
+		t.Fatalf("unmarshal of empty slice succeeded")
 	}
 }
 
 func TestUnmarshalUnexpectedPacket(t *testing.T) {
 	type S struct {
-		I uint32
+		I uint32 `sshtype:"43"`
 		S string
 		B bool
 	}
 
-	s := S{42, "hello", true}
-	packet := marshal(42, s)
+	s := S{11, "hello", true}
+	packet := Marshal(s)
+	packet[0] = 42
 	roundtrip := S{}
-	err := unmarshal(&roundtrip, packet, 43)
+	err := Unmarshal(packet, &roundtrip)
 	if err == nil {
 		t.Fatal("expected error, not nil")
 	}
-	want := UnexpectedMessageError{43, 42}
-	if got, ok := err.(UnexpectedMessageError); !ok || want != got {
-		t.Fatal("expected %q, got %q", want, got)
+}
+
+func TestMarshalPtr(t *testing.T) {
+	s := struct {
+		S string
+	}{"hello"}
+
+	m1 := Marshal(s)
+	m2 := Marshal(&s)
+	if !bytes.Equal(m1, m2) {
+		t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
 	}
 }
 
@@ -119,9 +138,9 @@
 	}
 
 	s := S{42, "hello", true}
-	packet := marshal(0, s)
+	packet := Marshal(s)
 	roundtrip := S{}
-	unmarshal(&roundtrip, packet, 0)
+	Unmarshal(packet, &roundtrip)
 
 	if !reflect.DeepEqual(s, roundtrip) {
 		t.Errorf("got %#v, want %#v", roundtrip, s)
@@ -133,7 +152,7 @@
 		I uint32
 	}
 	s := S2{42}
-	packet := marshal(0, s)
+	packet := Marshal(s)
 	i, rest, ok := parseUint32(packet)
 	if len(rest) > 0 || !ok {
 		t.Errorf("parseInt(%q): parse error", packet)
@@ -190,43 +209,36 @@
 	return reflect.ValueOf(dhi)
 }
 
-// TODO(dfc) maybe this can be removed in the future if testing/quick can handle
-// derived basic types.
-func (RejectionReason) Generate(rand *rand.Rand, size int) reflect.Value {
-	m := RejectionReason(Prohibited)
-	return reflect.ValueOf(m)
-}
-
 var (
 	_kexInitMsg   = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
 	_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
 
-	_kexInit   = marshal(msgKexInit, _kexInitMsg)
-	_kexDHInit = marshal(msgKexDHInit, _kexDHInitMsg)
+	_kexInit   = Marshal(_kexInitMsg)
+	_kexDHInit = Marshal(_kexDHInitMsg)
 )
 
 func BenchmarkMarshalKexInitMsg(b *testing.B) {
 	for i := 0; i < b.N; i++ {
-		marshal(msgKexInit, _kexInitMsg)
+		Marshal(_kexInitMsg)
 	}
 }
 
 func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
 	m := new(kexInitMsg)
 	for i := 0; i < b.N; i++ {
-		unmarshal(m, _kexInit, msgKexInit)
+		Unmarshal(_kexInit, m)
 	}
 }
 
 func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
 	for i := 0; i < b.N; i++ {
-		marshal(msgKexDHInit, _kexDHInitMsg)
+		Marshal(_kexDHInitMsg)
 	}
 }
 
 func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
 	m := new(kexDHInitMsg)
 	for i := 0; i < b.N; i++ {
-		unmarshal(m, _kexDHInit, msgKexDHInit)
+		Unmarshal(_kexDHInit, m)
 	}
 }
diff --git a/ssh/mux.go b/ssh/mux.go
new file mode 100644
index 0000000..5af7c16
--- /dev/null
+++ b/ssh/mux.go
@@ -0,0 +1,352 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+	"log"
+	"sync"
+	"sync/atomic"
+)
+
+// debugMux, if set, causes messages in the connection protocol to be
+// logged.
+const debugMux = false
+
+// chanList is a thread safe channel list.
+type chanList struct {
+	// protects concurrent access to chans
+	sync.Mutex
+
+	// chans are indexed by the local id of the channel, which the
+	// other side should send in the PeersId field.
+	chans []*channel
+
+	// This is a debugging aid: it offsets all IDs by this
+	// amount. This helps distinguish otherwise identical
+	// server/client muxes
+	offset uint32
+}
+
+// Assigns a channel ID to the given channel.
+func (c *chanList) add(ch *channel) uint32 {
+	c.Lock()
+	defer c.Unlock()
+	for i := range c.chans {
+		if c.chans[i] == nil {
+			c.chans[i] = ch
+			return uint32(i) + c.offset
+		}
+	}
+	c.chans = append(c.chans, ch)
+	return uint32(len(c.chans)-1) + c.offset
+}
+
+// getChan returns the channel for the given ID.
+func (c *chanList) getChan(id uint32) *channel {
+	id -= c.offset
+
+	c.Lock()
+	defer c.Unlock()
+	if id < uint32(len(c.chans)) {
+		return c.chans[id]
+	}
+	return nil
+}
+
+func (c *chanList) remove(id uint32) {
+	id -= c.offset
+	c.Lock()
+	if id < uint32(len(c.chans)) {
+		c.chans[id] = nil
+	}
+	c.Unlock()
+}
+
+// dropAll forgets all channels it knows, returning them in a slice.
+func (c *chanList) dropAll() []*channel {
+	c.Lock()
+	defer c.Unlock()
+	var r []*channel
+
+	for _, ch := range c.chans {
+		if ch == nil {
+			continue
+		}
+		r = append(r, ch)
+	}
+	c.chans = nil
+	return r
+}
+
+// mux represents the state for the SSH connection protocol, which
+// multiplexes many channels onto a single packet transport.
+type mux struct {
+	conn     packetConn
+	chanList chanList
+
+	incomingChannels chan NewChannel
+
+	globalSentMu     sync.Mutex
+	globalResponses  chan interface{}
+	incomingRequests chan *Request
+
+	errCond *sync.Cond
+	err     error
+}
+
+// Each new chanList instantiation has a different offset.
+var globalOff uint32
+
+func (m *mux) Wait() error {
+	m.errCond.L.Lock()
+	defer m.errCond.L.Unlock()
+	for m.err == nil {
+		m.errCond.Wait()
+	}
+	return m.err
+}
+
+// newMux returns a mux that runs over the given connection.
+func newMux(p packetConn) *mux {
+	m := &mux{
+		conn:             p,
+		incomingChannels: make(chan NewChannel, 16),
+		globalResponses:  make(chan interface{}, 1),
+		incomingRequests: make(chan *Request, 16),
+		errCond:          newCond(),
+	}
+	m.chanList.offset = atomic.AddUint32(&globalOff, 1)
+	go m.loop()
+	return m
+}
+
+func (m *mux) sendMessage(msg interface{}) error {
+	p := Marshal(msg)
+	return m.conn.writePacket(p)
+}
+
+func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
+	if wantReply {
+		m.globalSentMu.Lock()
+		defer m.globalSentMu.Unlock()
+	}
+
+	if err := m.sendMessage(globalRequestMsg{
+		Type:      name,
+		WantReply: wantReply,
+		Data:      payload,
+	}); err != nil {
+		return false, nil, err
+	}
+
+	if !wantReply {
+		return false, nil, nil
+	}
+
+	msg, ok := <-m.globalResponses
+	if !ok {
+		return false, nil, io.EOF
+	}
+	switch msg := msg.(type) {
+	case *globalRequestFailureMsg:
+		return false, msg.Data, nil
+	case *globalRequestSuccessMsg:
+		return true, msg.Data, nil
+	default:
+		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
+	}
+}
+
+// ackRequest must be called after processing a global request that
+// has WantReply set.
+func (m *mux) ackRequest(ok bool, data []byte) error {
+	if ok {
+		return m.sendMessage(globalRequestSuccessMsg{Data: data})
+	}
+	return m.sendMessage(globalRequestFailureMsg{Data: data})
+}
+
+// TODO(hanwen): Disconnect is a transport layer message. We should
+// probably send and receive Disconnect somewhere in the transport
+// code.
+
+// Disconnect sends a disconnect message.
+func (m *mux) Disconnect(reason uint32, message string) error {
+	return m.sendMessage(disconnectMsg{
+		Reason:  reason,
+		Message: message,
+	})
+}
+
+func (m *mux) Close() error {
+	return m.conn.Close()
+}
+
+// loop runs the connection machine. It will process packets until an
+// error is encountered. To synchronize on loop exit, use mux.Wait.
+func (m *mux) loop() {
+	var err error
+	for err == nil {
+		err = m.onePacket()
+	}
+
+	for _, ch := range m.chanList.dropAll() {
+		ch.close()
+	}
+
+	close(m.incomingChannels)
+	close(m.incomingRequests)
+	close(m.globalResponses)
+
+	m.conn.Close()
+
+	m.errCond.L.Lock()
+	m.err = err
+	m.errCond.Broadcast()
+	m.errCond.L.Unlock()
+
+	if debugMux {
+		log.Println("loop exit", err)
+	}
+}
+
+// onePacket reads and processes one packet.
+func (m *mux) onePacket() error {
+	packet, err := m.conn.readPacket()
+	if err != nil {
+		return err
+	}
+
+	if debugMux {
+		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
+			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
+		} else {
+			p, _ := decode(packet)
+			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
+		}
+	}
+
+	switch packet[0] {
+	case msgNewKeys:
+		// Ignore notification of key change.
+		return nil
+	case msgDisconnect:
+		return m.handleDisconnect(packet)
+	case msgChannelOpen:
+		return m.handleChannelOpen(packet)
+	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
+		return m.handleGlobalPacket(packet)
+	}
+
+	// assume a channel packet.
+	if len(packet) < 5 {
+		return parseError(packet[0])
+	}
+	id := binary.BigEndian.Uint32(packet[1:])
+	ch := m.chanList.getChan(id)
+	if ch == nil {
+		return fmt.Errorf("ssh: invalid channel %d", id)
+	}
+
+	return ch.handlePacket(packet)
+}
+
+func (m *mux) handleDisconnect(packet []byte) error {
+	var d disconnectMsg
+	if err := Unmarshal(packet, &d); err != nil {
+		return err
+	}
+
+	if debugMux {
+		log.Printf("caught disconnect: %v", d)
+	}
+	return &d
+}
+
+func (m *mux) handleGlobalPacket(packet []byte) error {
+	msg, err := decode(packet)
+	if err != nil {
+		return err
+	}
+
+	switch msg := msg.(type) {
+	case *globalRequestMsg:
+		m.incomingRequests <- &Request{
+			Type:      msg.Type,
+			WantReply: msg.WantReply,
+			Payload:   msg.Data,
+			mux:       m,
+		}
+	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
+		m.globalResponses <- msg
+	default:
+		panic(fmt.Sprintf("not a global message %#v", msg))
+	}
+
+	return nil
+}
+
+// handleChannelOpen schedules a channel to be Accept()ed.
+func (m *mux) handleChannelOpen(packet []byte) error {
+	var msg channelOpenMsg
+	if err := Unmarshal(packet, &msg); err != nil {
+		return err
+	}
+
+	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+		failMsg := channelOpenFailureMsg{
+			PeersId:  msg.PeersId,
+			Reason:   ConnectionFailed,
+			Message:  "invalid request",
+			Language: "en_US.UTF-8",
+		}
+		return m.sendMessage(failMsg)
+	}
+
+	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
+	c.remoteId = msg.PeersId
+	c.maxRemotePayload = msg.MaxPacketSize
+	c.remoteWin.add(msg.PeersWindow)
+	m.incomingChannels <- c
+	return nil
+}
+
+func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
+	ch, err := m.openChannel(chanType, extra)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return ch, ch.incomingRequests, nil
+}
+
+func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
+	ch := m.newChannel(chanType, channelOutbound, extra)
+
+	ch.maxIncomingPayload = channelMaxPacket
+
+	open := channelOpenMsg{
+		ChanType:         chanType,
+		PeersWindow:      ch.myWindow,
+		MaxPacketSize:    ch.maxIncomingPayload,
+		TypeSpecificData: extra,
+		PeersId:          ch.localId,
+	}
+	if err := m.sendMessage(open); err != nil {
+		return nil, err
+	}
+
+	switch msg := (<-ch.msg).(type) {
+	case *channelOpenConfirmMsg:
+		return ch, nil
+	case *channelOpenFailureMsg:
+		return nil, &OpenChannelError{msg.Reason, msg.Message}
+	default:
+		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
+	}
+}
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
new file mode 100644
index 0000000..e18afe7
--- /dev/null
+++ b/ssh/mux_test.go
@@ -0,0 +1,483 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"io"
+	"io/ioutil"
+	"sync"
+	"testing"
+)
+
+func muxPair() (*mux, *mux) {
+	a, b := memPipe()
+
+	s := newMux(a)
+	c := newMux(b)
+
+	return s, c
+}
+
+// Returns both ends of a channel, and the mux for the the 2nd
+// channel.
+func channelPair(t *testing.T) (*channel, *channel, *mux) {
+	c, s := muxPair()
+
+	res := make(chan *channel, 1)
+	go func() {
+		newCh, ok := <-s.incomingChannels
+		if !ok {
+			t.Fatalf("No incoming channel")
+		}
+		if newCh.ChannelType() != "chan" {
+			t.Fatalf("got type %q want chan", newCh.ChannelType())
+		}
+		ch, _, err := newCh.Accept()
+		if err != nil {
+			t.Fatalf("Accept %v", err)
+		}
+		res <- ch.(*channel)
+	}()
+
+	ch, err := c.openChannel("chan", nil)
+	if err != nil {
+		t.Fatalf("OpenChannel: %v", err)
+	}
+
+	return <-res, ch, c
+}
+
+func TestMuxReadWrite(t *testing.T) {
+	s, c, mux := channelPair(t)
+	defer s.Close()
+	defer c.Close()
+	defer mux.Close()
+
+	magic := "hello world"
+	magicExt := "hello stderr"
+	go func() {
+		_, err := s.Write([]byte(magic))
+		if err != nil {
+			t.Fatalf("Write: %v", err)
+		}
+		_, err = s.Extended(1).Write([]byte(magicExt))
+		if err != nil {
+			t.Fatalf("Write: %v", err)
+		}
+		err = s.Close()
+		if err != nil {
+			t.Fatalf("Close: %v", err)
+		}
+	}()
+
+	var buf [1024]byte
+	n, err := c.Read(buf[:])
+	if err != nil {
+		t.Fatalf("server Read: %v", err)
+	}
+	got := string(buf[:n])
+	if got != magic {
+		t.Fatalf("server: got %q want %q", got, magic)
+	}
+
+	n, err = c.Extended(1).Read(buf[:])
+	if err != nil {
+		t.Fatalf("server Read: %v", err)
+	}
+
+	got = string(buf[:n])
+	if got != magicExt {
+		t.Fatalf("server: got %q want %q", got, magic)
+	}
+}
+
+func TestMuxChannelOverflow(t *testing.T) {
+	reader, writer, mux := channelPair(t)
+	defer reader.Close()
+	defer writer.Close()
+	defer mux.Close()
+
+	wDone := make(chan int, 1)
+	go func() {
+		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+			t.Errorf("could not fill window: %v", err)
+		}
+		writer.Write(make([]byte, 1))
+		wDone <- 1
+	}()
+	writer.remoteWin.waitWriterBlocked()
+
+	// Send 1 byte.
+	packet := make([]byte, 1+4+4+1)
+	packet[0] = msgChannelData
+	marshalUint32(packet[1:], writer.remoteId)
+	marshalUint32(packet[5:], uint32(1))
+	packet[9] = 42
+
+	if err := writer.mux.conn.writePacket(packet); err != nil {
+		t.Errorf("could not send packet")
+	}
+	if _, err := reader.SendRequest("hello", true, nil); err == nil {
+		t.Errorf("SendRequest succeeded.")
+	}
+	<-wDone
+}
+
+func TestMuxChannelCloseWriteUnblock(t *testing.T) {
+	reader, writer, mux := channelPair(t)
+	defer reader.Close()
+	defer writer.Close()
+	defer mux.Close()
+
+	wDone := make(chan int, 1)
+	go func() {
+		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+			t.Errorf("could not fill window: %v", err)
+		}
+		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
+			t.Errorf("got %v, want EOF for unblock write", err)
+		}
+		wDone <- 1
+	}()
+
+	writer.remoteWin.waitWriterBlocked()
+	reader.Close()
+	<-wDone
+}
+
+func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
+	reader, writer, mux := channelPair(t)
+	defer reader.Close()
+	defer writer.Close()
+	defer mux.Close()
+
+	wDone := make(chan int, 1)
+	go func() {
+		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+			t.Errorf("could not fill window: %v", err)
+		}
+		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
+			t.Errorf("got %v, want EOF for unblock write", err)
+		}
+		wDone <- 1
+	}()
+
+	writer.remoteWin.waitWriterBlocked()
+	mux.Close()
+	<-wDone
+}
+
+func TestMuxReject(t *testing.T) {
+	client, server := muxPair()
+	defer server.Close()
+	defer client.Close()
+
+	go func() {
+		ch, ok := <-server.incomingChannels
+		if !ok {
+			t.Fatalf("Accept")
+		}
+		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
+			t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+		}
+		ch.Reject(RejectionReason(42), "message")
+	}()
+
+	ch, err := client.openChannel("ch", []byte("extra"))
+	if ch != nil {
+		t.Fatal("openChannel not rejected")
+	}
+
+	ocf, ok := err.(*OpenChannelError)
+	if !ok {
+		t.Errorf("got %#v want *OpenChannelError", err)
+	} else if ocf.Reason != 42 || ocf.Message != "message" {
+		t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
+	}
+
+	want := "ssh: rejected: unknown reason 42 (message)"
+	if err.Error() != want {
+		t.Errorf("got %q, want %q", err.Error(), want)
+	}
+}
+
+func TestMuxChannelRequest(t *testing.T) {
+	client, server, mux := channelPair(t)
+	defer server.Close()
+	defer client.Close()
+	defer mux.Close()
+
+	var received int
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		for r := range server.incomingRequests {
+			received++
+			r.Reply(r.Type == "yes", nil)
+		}
+		wg.Done()
+	}()
+	_, err := client.SendRequest("yes", false, nil)
+	if err != nil {
+		t.Fatalf("SendRequest: %v", err)
+	}
+	ok, err := client.SendRequest("yes", true, nil)
+	if err != nil {
+		t.Fatalf("SendRequest: %v", err)
+	}
+
+	if !ok {
+		t.Errorf("SendRequest(yes): %v", ok)
+
+	}
+
+	ok, err = client.SendRequest("no", true, nil)
+	if err != nil {
+		t.Fatalf("SendRequest: %v", err)
+	}
+	if ok {
+		t.Errorf("SendRequest(no): %v", ok)
+
+	}
+
+	client.Close()
+	wg.Wait()
+
+	if received != 3 {
+		t.Errorf("got %d requests, want %d", received, 3)
+	}
+}
+
+func TestMuxGlobalRequest(t *testing.T) {
+	clientMux, serverMux := muxPair()
+	defer serverMux.Close()
+	defer clientMux.Close()
+
+	var seen bool
+	go func() {
+		for r := range serverMux.incomingRequests {
+			seen = seen || r.Type == "peek"
+			if r.WantReply {
+				err := r.Reply(r.Type == "yes",
+					append([]byte(r.Type), r.Payload...))
+				if err != nil {
+					t.Errorf("AckRequest: %v", err)
+				}
+			}
+		}
+	}()
+
+	_, _, err := clientMux.SendRequest("peek", false, nil)
+	if err != nil {
+		t.Errorf("SendRequest: %v", err)
+	}
+
+	ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
+	if !ok || string(data) != "yesa" || err != nil {
+		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
+			ok, data, err)
+	}
+	if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
+		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
+			ok, data, err)
+	}
+
+	if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
+		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
+			ok, data, err)
+	}
+
+	clientMux.Disconnect(0, "")
+	if !seen {
+		t.Errorf("never saw 'peek' request")
+	}
+}
+
+func TestMuxGlobalRequestUnblock(t *testing.T) {
+	clientMux, serverMux := muxPair()
+	defer serverMux.Close()
+	defer clientMux.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		_, _, err := clientMux.SendRequest("hello", true, nil)
+		result <- err
+	}()
+
+	<-serverMux.incomingRequests
+	serverMux.conn.Close()
+	err := <-result
+
+	if err != io.EOF {
+		t.Errorf("want EOF, got %v", io.EOF)
+	}
+}
+
+func TestMuxChannelRequestUnblock(t *testing.T) {
+	a, b, connB := channelPair(t)
+	defer a.Close()
+	defer b.Close()
+	defer connB.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		_, err := a.SendRequest("hello", true, nil)
+		result <- err
+	}()
+
+	<-b.incomingRequests
+	connB.conn.Close()
+	err := <-result
+
+	if err != io.EOF {
+		t.Errorf("want EOF, got %v", err)
+	}
+}
+
+func TestMuxDisconnect(t *testing.T) {
+	a, b := muxPair()
+	defer a.Close()
+	defer b.Close()
+
+	go func() {
+		for r := range b.incomingRequests {
+			r.Reply(true, nil)
+		}
+	}()
+
+	a.Disconnect(42, "whatever")
+	ok, _, err := a.SendRequest("hello", true, nil)
+	if ok || err == nil {
+		t.Errorf("got reply after disconnecting")
+	}
+	err = b.Wait()
+	if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
+		t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
+	}
+}
+
+func TestMuxCloseChannel(t *testing.T) {
+	r, w, mux := channelPair(t)
+	defer mux.Close()
+	defer r.Close()
+	defer w.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		var b [1024]byte
+		_, err := r.Read(b[:])
+		result <- err
+	}()
+	if err := w.Close(); err != nil {
+		t.Errorf("w.Close: %v", err)
+	}
+
+	if _, err := w.Write([]byte("hello")); err != io.EOF {
+		t.Errorf("got err %v, want io.EOF after Close", err)
+	}
+
+	if err := <-result; err != io.EOF {
+		t.Errorf("got %v (%T), want io.EOF", err, err)
+	}
+}
+
+func TestMuxCloseWriteChannel(t *testing.T) {
+	r, w, mux := channelPair(t)
+	defer mux.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		var b [1024]byte
+		_, err := r.Read(b[:])
+		result <- err
+	}()
+	if err := w.CloseWrite(); err != nil {
+		t.Errorf("w.CloseWrite: %v", err)
+	}
+
+	if _, err := w.Write([]byte("hello")); err != io.EOF {
+		t.Errorf("got err %v, want io.EOF after CloseWrite", err)
+	}
+
+	if err := <-result; err != io.EOF {
+		t.Errorf("got %v (%T), want io.EOF", err, err)
+	}
+}
+
+func TestMuxInvalidRecord(t *testing.T) {
+	a, b := muxPair()
+	defer a.Close()
+	defer b.Close()
+
+	packet := make([]byte, 1+4+4+1)
+	packet[0] = msgChannelData
+	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
+	marshalUint32(packet[5:], 1)
+	packet[9] = 42
+
+	a.conn.writePacket(packet)
+	go a.SendRequest("hello", false, nil)
+	// 'a' wrote an invalid packet, so 'b' has exited.
+	req, ok := <-b.incomingRequests
+	if ok {
+		t.Errorf("got request %#v after receiving invalid packet", req)
+	}
+}
+
+func TestZeroWindowAdjust(t *testing.T) {
+	a, b, mux := channelPair(t)
+	defer a.Close()
+	defer b.Close()
+	defer mux.Close()
+
+	go func() {
+		io.WriteString(a, "hello")
+		// bogus adjust.
+		a.sendMessage(windowAdjustMsg{})
+		io.WriteString(a, "world")
+		a.Close()
+	}()
+
+	want := "helloworld"
+	c, _ := ioutil.ReadAll(b)
+	if string(c) != want {
+		t.Errorf("got %q want %q", c, want)
+	}
+}
+
+func TestMuxMaxPacketSize(t *testing.T) {
+	a, b, mux := channelPair(t)
+	defer a.Close()
+	defer b.Close()
+	defer mux.Close()
+
+	large := make([]byte, a.maxRemotePayload+1)
+	packet := make([]byte, 1+4+4+1+len(large))
+	packet[0] = msgChannelData
+	marshalUint32(packet[1:], a.remoteId)
+	marshalUint32(packet[5:], uint32(len(large)))
+	packet[9] = 42
+
+	if err := a.mux.conn.writePacket(packet); err != nil {
+		t.Errorf("could not send packet")
+	}
+
+	go a.SendRequest("hello", false, nil)
+
+	_, ok := <-b.incomingRequests
+	if ok {
+		t.Errorf("connection still alive after receiving large packet.")
+	}
+}
+
+// Don't ship code with debug=true.
+func TestDebug(t *testing.T) {
+	if debugMux {
+		t.Error("mux debug switched on")
+	}
+	if debugHandshake {
+		t.Error("handshake debug switched on")
+	}
+}
diff --git a/ssh/server.go b/ssh/server.go
index b4defbe..7a53d57 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -6,38 +6,52 @@
 
 import (
 	"bytes"
-	"crypto/rand"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
 	"net"
-	"sync"
-
-	_ "crypto/sha1"
 )
 
-type ServerConfig struct {
-	hostKeys []Signer
+// The Permissions type holds fine-grained permissions that are
+// specific to a user or a specific authentication method for a
+// user. Permissions, except for "source-address", must be enforced in
+// the server application layer, after successful authentication. The
+// Permissions are passed on in ServerConn so a server implementation
+// can honor them.
+type Permissions struct {
+	// Critical options restrict default permissions. Common
+	// restrictions are "source-address" and "force-command". If
+	// the server cannot enforce the restriction, or does not
+	// recognize it, the user should not authenticate.
+	CriticalOptions map[string]string
 
-	// Rand provides the source of entropy for key exchange. If Rand is
-	// nil, the cryptographic random reader in package crypto/rand will
-	// be used.
-	Rand io.Reader
+	// Extensions are extra functionality that the server may
+	// offer on authenticated connections. Common extensions are
+	// "permit-agent-forwarding", "permit-X11-forwarding". Lack of
+	// support for an extension does not preclude authenticating a
+	// user.
+	Extensions map[string]string
+}
+
+// ServerConfig holds server specific configuration data.
+type ServerConfig struct {
+	// Config contains configuration shared between client and server.
+	Config
+
+	hostKeys []Signer
 
 	// NoClientAuth is true if clients are allowed to connect without
 	// authenticating.
 	NoClientAuth bool
 
-	// PasswordCallback, if non-nil, is called when a user attempts to
-	// authenticate using a password. It may be called concurrently from
-	// several goroutines.
-	PasswordCallback func(conn *ServerConn, user, password string) bool
+	// PasswordCallback, if non-nil, is called when a user
+	// attempts to authenticate using a password.
+	PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
 
 	// PublicKeyCallback, if non-nil, is called when a client attempts public
 	// key authentication. It must return true if the given public key is
-	// valid for the given user.
-	PublicKeyCallback func(conn *ServerConn, user, algo string, pubkey []byte) bool
+	// valid for the given user. For example, see CertChecker.Authenticate.
+	PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
 
 	// KeyboardInteractiveCallback, if non-nil, is called when
 	// keyboard-interactive authentication is selected (RFC
@@ -46,24 +60,19 @@
 	// Challenge rounds. To avoid information leaks, the client
 	// should be presented a challenge even if the user is
 	// unknown.
-	KeyboardInteractiveCallback func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool
+	KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
 
-	// Cryptographic-related configuration.
-	Crypto CryptoConfig
-}
-
-func (c *ServerConfig) rand() io.Reader {
-	if c.Rand == nil {
-		return rand.Reader
-	}
-	return c.Rand
+	// AuthLogCallback, if non-nil, is called to log all authentication
+	// attempts.
+	AuthLogCallback func(conn ConnMetadata, method string, err error)
 }
 
 // AddHostKey adds a private key as a host key. If an existing host
-// key exists with the same algorithm, it is overwritten.
+// key exists with the same algorithm, it is overwritten. Each server
+// config must have at least one host key.
 func (s *ServerConfig) AddHostKey(key Signer) {
 	for i, k := range s.hostKeys {
-		if k.PublicKey().PublicKeyAlgo() == key.PublicKey().PublicKeyAlgo() {
+		if k.PublicKey().Type() == key.PublicKey().Type() {
 			s.hostKeys[i] = key
 			return
 		}
@@ -72,68 +81,73 @@
 	s.hostKeys = append(s.hostKeys, key)
 }
 
-// SetRSAPrivateKey sets the private key for a Server. A Server must have a
-// private key configured in order to accept connections. The private key must
-// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
-// typically contains such a key.
-func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error {
-	priv, err := ParsePrivateKey(pemBytes)
-	if err != nil {
-		return err
-	}
-	s.AddHostKey(priv)
-	return nil
+// cachedPubKey contains the results of querying whether a public key is
+// acceptable for a user.
+type cachedPubKey struct {
+	user       string
+	pubKeyData []byte
+	result     error
+	perms      *Permissions
 }
 
-// cachedPubKey contains the results of querying whether a public key is
-// acceptable for a user. The cache only applies to a single ServerConn.
-type cachedPubKey struct {
-	user, algo string
-	pubKey     []byte
-	result     bool
+func (k1 *cachedPubKey) Equal(k2 *cachedPubKey) bool {
+	return k1.user == k2.user && bytes.Equal(k1.pubKeyData, k2.pubKeyData)
 }
 
 const maxCachedPubKeys = 16
 
-// A ServerConn represents an incoming connection.
-type ServerConn struct {
-	transport *transport
-	config    *ServerConfig
-
-	channels   map[uint32]*serverChan
-	nextChanId uint32
-
-	// lock protects err and channels.
-	lock sync.Mutex
-	err  error
-
-	// cachedPubKeys contains the cache results of tests for public keys.
-	// Since SSH clients will query whether a public key is acceptable
-	// before attempting to authenticate with it, we end up with duplicate
-	// queries for public key validity.
-	cachedPubKeys []cachedPubKey
-
-	// User holds the successfully authenticated user name.
-	// It is empty if no authentication is used.  It is populated before
-	// any authentication callback is called and not assigned to after that.
-	User string
-
-	// ClientVersion is the client's version, populated after
-	// Handshake is called. It should not be modified.
-	ClientVersion []byte
-
-	// Our version.
-	serverVersion []byte
+// pubKeyCache caches tests for public keys.  Since SSH clients
+// will query whether a public key is acceptable before attempting to
+// authenticate with it, we end up with duplicate queries for public
+// key validity.  The cache only applies to a single ServerConn.
+type pubKeyCache struct {
+	keys []cachedPubKey
 }
 
-// Server returns a new SSH server connection
-// using c as the underlying transport.
-func Server(c net.Conn, config *ServerConfig) *ServerConn {
-	return &ServerConn{
-		transport: newTransport(c, config.rand(), false /* not client */),
-		channels:  make(map[uint32]*serverChan),
-		config:    config,
+// get returns the result for a given user/algo/key tuple.
+func (c *pubKeyCache) get(candidate cachedPubKey) (result error, ok bool) {
+	for _, k := range c.keys {
+		if k.Equal(&candidate) {
+			return k.result, true
+		}
 	}
+	return errors.New("ssh: not in cache"), false
+}
+
+// add adds the given tuple to the cache.
+func (c *pubKeyCache) add(candidate cachedPubKey) {
+	if len(c.keys) < maxCachedPubKeys {
+		c.keys = append(c.keys, candidate)
+	}
+}
+
+// ServerConn is an authenticated SSH connection, as seen from the
+// server
+type ServerConn struct {
+	Conn
+
+	// If the succeeding authentication callback returned a
+	// non-nil Permissions pointer, it is stored here.
+	Permissions *Permissions
+}
+
+// NewServerConn starts a new SSH server with c as the underlying
+// transport.  It starts with a handshake and, if the handshake is
+// unsuccessful, it closes the connection and returns an error.  The
+// Request and NewChannel channels must be serviced, or the connection
+// will hang.
+func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
+	fullConf := *config
+	fullConf.SetDefaults()
+	s := &connection{
+		sshConn: sshConn{conn: c},
+	}
+	perms, err := s.serverHandshake(&fullConf)
+	if err != nil {
+		c.Close()
+		return nil, nil, nil, err
+	}
+	return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
 }
 
 // signAndMarshal signs the data with the appropriate algorithm,
@@ -144,134 +158,60 @@
 		return nil, err
 	}
 
-	return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
+	return Marshal(sig), nil
 }
 
-// Close closes the connection.
-func (s *ServerConn) Close() error { return s.transport.Close() }
+// handshake performs key exchange and user authentication.
+func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
+	if len(config.hostKeys) == 0 {
+		return nil, errors.New("ssh: server has no host keys")
+	}
 
-// LocalAddr returns the local network address.
-func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
-
-// RemoteAddr returns the remote network address.
-func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
-
-// Handshake performs an SSH transport and client authentication on the given ServerConn.
-func (s *ServerConn) Handshake() error {
 	var err error
 	s.serverVersion = []byte(packageVersion)
-	s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersion)
+	s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
 	if err != nil {
-		return err
+		return nil, err
 	}
-	if err := s.clientInitHandshake(nil, nil); err != nil {
-		return err
+
+	tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
+	s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
+
+	if err := s.transport.requestKeyChange(); err != nil {
+		return nil, err
+	}
+
+	if packet, err := s.transport.readPacket(); err != nil {
+		return nil, err
+	} else if packet[0] != msgNewKeys {
+		return nil, unexpectedMessageError(msgNewKeys, packet[0])
 	}
 
 	var packet []byte
 	if packet, err = s.transport.readPacket(); err != nil {
-		return err
+		return nil, err
 	}
+
 	var serviceRequest serviceRequestMsg
-	if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
-		return err
+	if err = Unmarshal(packet, &serviceRequest); err != nil {
+		return nil, err
 	}
 	if serviceRequest.Service != serviceUserAuth {
-		return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
+		return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
 	}
 	serviceAccept := serviceAcceptMsg{
 		Service: serviceUserAuth,
 	}
-	if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
-		return err
+	if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
+		return nil, err
 	}
 
-	if err := s.authenticate(); err != nil {
-		return err
-	}
-	return err
-}
-
-func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
-	serverKexInit := kexInitMsg{
-		KexAlgos:                s.config.Crypto.kexes(),
-		CiphersClientServer:     s.config.Crypto.ciphers(),
-		CiphersServerClient:     s.config.Crypto.ciphers(),
-		MACsClientServer:        s.config.Crypto.macs(),
-		MACsServerClient:        s.config.Crypto.macs(),
-		CompressionClientServer: supportedCompressions,
-		CompressionServerClient: supportedCompressions,
-	}
-	for _, k := range s.config.hostKeys {
-		serverKexInit.ServerHostKeyAlgos = append(
-			serverKexInit.ServerHostKeyAlgos, k.PublicKey().PublicKeyAlgo())
-	}
-
-	serverKexInitPacket := marshal(msgKexInit, serverKexInit)
-	if err = s.transport.writePacket(serverKexInitPacket); err != nil {
-		return
-	}
-
-	if clientKexInitPacket == nil {
-		clientKexInit = new(kexInitMsg)
-		if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
-			return
-		}
-		if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
-			return
-		}
-	}
-
-	algs := findAgreedAlgorithms(clientKexInit, &serverKexInit)
-	if algs == nil {
-		return errors.New("ssh: no common algorithms")
-	}
-
-	if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
-		// The client sent a Kex message for the wrong algorithm,
-		// which we have to ignore.
-		if _, err = s.transport.readPacket(); err != nil {
-			return
-		}
-	}
-
-	var hostKey Signer
-	for _, k := range s.config.hostKeys {
-		if algs.hostKey == k.PublicKey().PublicKeyAlgo() {
-			hostKey = k
-		}
-	}
-
-	kex, ok := kexAlgoMap[algs.kex]
-	if !ok {
-		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
-	}
-
-	magics := handshakeMagics{
-		serverVersion: s.serverVersion,
-		clientVersion: s.ClientVersion,
-		serverKexInit: marshal(msgKexInit, serverKexInit),
-		clientKexInit: clientKexInitPacket,
-	}
-	result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
+	perms, err := s.serverAuthenticate(config)
 	if err != nil {
-		return err
+		return nil, err
 	}
-
-	if err = s.transport.prepareKeyChange(algs, result); err != nil {
-		return err
-	}
-
-	if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
-		return
-	}
-	if packet, err := s.transport.readPacket(); err != nil {
-		return err
-	} else if packet[0] != msgNewKeys {
-		return UnexpectedMessageError{msgNewKeys, packet[0]}
-	}
-
-	return
+	s.mux = newMux(s.transport)
+	return perms, err
 }
 
 func isAcceptableAlgo(algo string) bool {
@@ -283,181 +223,213 @@
 	return false
 }
 
-// testPubKey returns true if the given public key is acceptable for the user.
-func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool {
-	if s.config.PublicKeyCallback == nil || !isAcceptableAlgo(algo) {
-		return false
+func checkSourceAddress(addr net.Addr, sourceAddr string) error {
+	if addr == nil {
+		return errors.New("ssh: no address known for client, but source-address match required")
 	}
 
-	for _, c := range s.cachedPubKeys {
-		if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
-			return c.result
+	tcpAddr, ok := addr.(*net.TCPAddr)
+	if !ok {
+		return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
+	}
+
+	if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
+		if bytes.Equal(allowedIP, tcpAddr.IP) {
+			return nil
+		}
+	} else {
+		_, ipNet, err := net.ParseCIDR(sourceAddr)
+		if err != nil {
+			return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
+		}
+
+		if ipNet.Contains(tcpAddr.IP) {
+			return nil
 		}
 	}
 
-	result := s.config.PublicKeyCallback(s, user, algo, pubKey)
-	if len(s.cachedPubKeys) < maxCachedPubKeys {
-		c := cachedPubKey{
-			user:   user,
-			algo:   algo,
-			pubKey: make([]byte, len(pubKey)),
-			result: result,
-		}
-		copy(c.pubKey, pubKey)
-		s.cachedPubKeys = append(s.cachedPubKeys, c)
-	}
-
-	return result
+	return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
 }
 
-func (s *ServerConn) authenticate() error {
-	var userAuthReq userAuthRequestMsg
+func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
 	var err error
-	var packet []byte
+	var cache pubKeyCache
+	var perms *Permissions
 
 userAuthLoop:
 	for {
-		if packet, err = s.transport.readPacket(); err != nil {
-			return err
-		}
-		if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
-			return err
+		var userAuthReq userAuthRequestMsg
+		if packet, err := s.transport.readPacket(); err != nil {
+			return nil, err
+		} else if err = Unmarshal(packet, &userAuthReq); err != nil {
+			return nil, err
 		}
 
 		if userAuthReq.Service != serviceSSH {
-			return errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
+			return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
 		}
 
+		s.user = userAuthReq.User
+		perms = nil
+		authErr := errors.New("no auth passed yet")
+
 		switch userAuthReq.Method {
 		case "none":
-			if s.config.NoClientAuth {
-				break userAuthLoop
+			if config.NoClientAuth {
+				s.user = ""
+				authErr = nil
 			}
 		case "password":
-			if s.config.PasswordCallback == nil {
+			if config.PasswordCallback == nil {
+				authErr = errors.New("ssh: password auth not configured")
 				break
 			}
 			payload := userAuthReq.Payload
 			if len(payload) < 1 || payload[0] != 0 {
-				return ParseError{msgUserAuthRequest}
+				return nil, parseError(msgUserAuthRequest)
 			}
 			payload = payload[1:]
 			password, payload, ok := parseString(payload)
 			if !ok || len(payload) > 0 {
-				return ParseError{msgUserAuthRequest}
+				return nil, parseError(msgUserAuthRequest)
 			}
 
-			s.User = userAuthReq.User
-			if s.config.PasswordCallback(s, userAuthReq.User, string(password)) {
-				break userAuthLoop
-			}
+			perms, authErr = config.PasswordCallback(s, password)
 		case "keyboard-interactive":
-			if s.config.KeyboardInteractiveCallback == nil {
+			if config.KeyboardInteractiveCallback == nil {
+				authErr = errors.New("ssh: keyboard-interactive auth not configubred")
 				break
 			}
 
-			s.User = userAuthReq.User
-			if s.config.KeyboardInteractiveCallback(s, s.User, &sshClientKeyboardInteractive{s}) {
-				break userAuthLoop
-			}
+			prompter := &sshClientKeyboardInteractive{s}
+			perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
 		case "publickey":
-			if s.config.PublicKeyCallback == nil {
+			if config.PublicKeyCallback == nil {
+				authErr = errors.New("ssh: publickey auth not configured")
 				break
 			}
 			payload := userAuthReq.Payload
 			if len(payload) < 1 {
-				return ParseError{msgUserAuthRequest}
+				return nil, parseError(msgUserAuthRequest)
 			}
 			isQuery := payload[0] == 0
 			payload = payload[1:]
 			algoBytes, payload, ok := parseString(payload)
 			if !ok {
-				return ParseError{msgUserAuthRequest}
+				return nil, parseError(msgUserAuthRequest)
 			}
 			algo := string(algoBytes)
-
-			pubKey, payload, ok := parseString(payload)
-			if !ok {
-				return ParseError{msgUserAuthRequest}
+			if !isAcceptableAlgo(algo) {
+				authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
+				break
 			}
+
+			pubKeyData, payload, ok := parseString(payload)
+			if !ok {
+				return nil, parseError(msgUserAuthRequest)
+			}
+
+			pubKey, err := ParsePublicKey(pubKeyData)
+			if err != nil {
+				return nil, err
+			}
+			candidate := cachedPubKey{
+				user:       s.user,
+				pubKeyData: pubKeyData,
+			}
+			candidate.result, ok = cache.get(candidate)
+			if !ok {
+				candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
+				if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
+					candidate.result = checkSourceAddress(
+						s.RemoteAddr(),
+						candidate.perms.CriticalOptions[sourceAddressCriticalOption])
+				}
+				cache.add(candidate)
+			}
+
 			if isQuery {
 				// The client can query if the given public key
 				// would be okay.
 				if len(payload) > 0 {
-					return ParseError{msgUserAuthRequest}
+					return nil, parseError(msgUserAuthRequest)
 				}
-				if s.testPubKey(userAuthReq.User, algo, pubKey) {
+
+				if candidate.result == nil {
 					okMsg := userAuthPubKeyOkMsg{
 						Algo:   algo,
-						PubKey: string(pubKey),
+						PubKey: pubKeyData,
 					}
-					if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
-						return err
+					if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
+						return nil, err
 					}
 					continue userAuthLoop
 				}
+				authErr = candidate.result
 			} else {
 				sig, payload, ok := parseSignature(payload)
 				if !ok || len(payload) > 0 {
-					return ParseError{msgUserAuthRequest}
+					return nil, parseError(msgUserAuthRequest)
 				}
 				// Ensure the public key algo and signature algo
 				// are supported.  Compare the private key
 				// algorithm name that corresponds to algo with
 				// sig.Format.  This is usually the same, but
 				// for certs, the names differ.
-				if !isAcceptableAlgo(algo) || !isAcceptableAlgo(sig.Format) || pubAlgoToPrivAlgo(algo) != sig.Format {
+				if !isAcceptableAlgo(sig.Format) {
 					break
 				}
-				signedData := buildDataSignedForAuth(s.transport.sessionID, userAuthReq, algoBytes, pubKey)
-				key, _, ok := ParsePublicKey(pubKey)
-				if !ok {
-					return ParseError{msgUserAuthRequest}
+				signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
+
+				if err := pubKey.Verify(signedData, sig); err != nil {
+					return nil, err
 				}
 
-				if !key.Verify(signedData, sig.Blob) {
-					return ParseError{msgUserAuthRequest}
-				}
-				// TODO(jmpittman): Implement full validation for certificates.
-				s.User = userAuthReq.User
-				if s.testPubKey(userAuthReq.User, algo, pubKey) {
-					break userAuthLoop
-				}
+				authErr = candidate.result
+				perms = candidate.perms
 			}
+		default:
+			authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
+		}
+
+		if config.AuthLogCallback != nil {
+			config.AuthLogCallback(s, userAuthReq.Method, authErr)
+		}
+
+		if authErr == nil {
+			break userAuthLoop
 		}
 
 		var failureMsg userAuthFailureMsg
-		if s.config.PasswordCallback != nil {
+		if config.PasswordCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "password")
 		}
-		if s.config.PublicKeyCallback != nil {
+		if config.PublicKeyCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "publickey")
 		}
-		if s.config.KeyboardInteractiveCallback != nil {
+		if config.KeyboardInteractiveCallback != nil {
 			failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
 		}
 
 		if len(failureMsg.Methods) == 0 {
-			return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
+			return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
 		}
 
-		if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
-			return err
+		if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
+			return nil, err
 		}
 	}
 
-	packet = []byte{msgUserAuthSuccess}
-	if err = s.transport.writePacket(packet); err != nil {
-		return err
+	if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
+		return nil, err
 	}
-
-	return nil
+	return perms, nil
 }
 
 // sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
 // asking the client on the other side of a ServerConn.
 type sshClientKeyboardInteractive struct {
-	*ServerConn
+	*connection
 }
 
 func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
@@ -471,7 +443,7 @@
 		prompts = appendBool(prompts, echos[i])
 	}
 
-	if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
+	if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
 		Instruction: instruction,
 		NumPrompts:  uint32(len(questions)),
 		Prompts:     prompts,
@@ -484,19 +456,19 @@
 		return nil, err
 	}
 	if packet[0] != msgUserAuthInfoResponse {
-		return nil, UnexpectedMessageError{msgUserAuthInfoResponse, packet[0]}
+		return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
 	}
 	packet = packet[1:]
 
 	n, packet, ok := parseUint32(packet)
 	if !ok || int(n) != len(questions) {
-		return nil, &ParseError{msgUserAuthInfoResponse}
+		return nil, parseError(msgUserAuthInfoResponse)
 	}
 
 	for i := uint32(0); i < n; i++ {
 		ans, rest, ok := parseString(packet)
 		if !ok {
-			return nil, &ParseError{msgUserAuthInfoResponse}
+			return nil, parseError(msgUserAuthInfoResponse)
 		}
 
 		answers = append(answers, string(ans))
@@ -508,185 +480,3 @@
 
 	return answers, nil
 }
-
-const defaultWindowSize = 32768
-
-// Accept reads and processes messages on a ServerConn. It must be called
-// in order to demultiplex messages to any resulting Channels.
-func (s *ServerConn) Accept() (Channel, error) {
-	// TODO(dfc) s.lock is not held here so visibility of s.err is not guaranteed.
-	if s.err != nil {
-		return nil, s.err
-	}
-
-	for {
-		packet, err := s.transport.readPacket()
-		if err != nil {
-
-			s.lock.Lock()
-			s.err = err
-			s.lock.Unlock()
-
-			// TODO(dfc) s.lock protects s.channels but isn't being held here.
-			for _, c := range s.channels {
-				c.setDead()
-				c.handleData(nil)
-			}
-
-			return nil, err
-		}
-
-		switch packet[0] {
-		case msgChannelData:
-			if len(packet) < 9 {
-				// malformed data packet
-				return nil, ParseError{msgChannelData}
-			}
-			remoteId := binary.BigEndian.Uint32(packet[1:5])
-			s.lock.Lock()
-			c, ok := s.channels[remoteId]
-			if !ok {
-				s.lock.Unlock()
-				continue
-			}
-			if length := binary.BigEndian.Uint32(packet[5:9]); length > 0 {
-				packet = packet[9:]
-				c.handleData(packet[:length])
-			}
-			s.lock.Unlock()
-		default:
-			decoded, err := decode(packet)
-			if err != nil {
-				return nil, err
-			}
-			switch msg := decoded.(type) {
-			case *channelOpenMsg:
-				if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
-					return nil, errors.New("ssh: invalid MaxPacketSize from peer")
-				}
-				c := &serverChan{
-					channel: channel{
-						packetConn: s.transport,
-						remoteId:   msg.PeersId,
-						remoteWin:  window{Cond: newCond()},
-						maxPacket:  msg.MaxPacketSize,
-					},
-					chanType:    msg.ChanType,
-					extraData:   msg.TypeSpecificData,
-					myWindow:    defaultWindowSize,
-					serverConn:  s,
-					cond:        newCond(),
-					pendingData: make([]byte, defaultWindowSize),
-				}
-				c.remoteWin.add(msg.PeersWindow)
-				s.lock.Lock()
-				c.localId = s.nextChanId
-				s.nextChanId++
-				s.channels[c.localId] = c
-				s.lock.Unlock()
-				return c, nil
-
-			case *channelRequestMsg:
-				s.lock.Lock()
-				c, ok := s.channels[msg.PeersId]
-				if !ok {
-					s.lock.Unlock()
-					continue
-				}
-				c.handlePacket(msg)
-				s.lock.Unlock()
-
-			case *windowAdjustMsg:
-				s.lock.Lock()
-				c, ok := s.channels[msg.PeersId]
-				if !ok {
-					s.lock.Unlock()
-					continue
-				}
-				c.handlePacket(msg)
-				s.lock.Unlock()
-
-			case *channelEOFMsg:
-				s.lock.Lock()
-				c, ok := s.channels[msg.PeersId]
-				if !ok {
-					s.lock.Unlock()
-					continue
-				}
-				c.handlePacket(msg)
-				s.lock.Unlock()
-
-			case *channelCloseMsg:
-				s.lock.Lock()
-				c, ok := s.channels[msg.PeersId]
-				if !ok {
-					s.lock.Unlock()
-					continue
-				}
-				c.handlePacket(msg)
-				s.lock.Unlock()
-
-			case *globalRequestMsg:
-				if msg.WantReply {
-					if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
-						return nil, err
-					}
-				}
-
-			case *kexInitMsg:
-				s.lock.Lock()
-				if err := s.clientInitHandshake(msg, packet); err != nil {
-					s.lock.Unlock()
-					return nil, err
-				}
-				s.lock.Unlock()
-			case *disconnectMsg:
-				return nil, io.EOF
-			default:
-				// Unknown message. Ignore.
-			}
-		}
-	}
-
-	panic("unreachable")
-}
-
-// A Listener implements a network listener (net.Listener) for SSH connections.
-type Listener struct {
-	listener net.Listener
-	config   *ServerConfig
-}
-
-// Addr returns the listener's network address.
-func (l *Listener) Addr() net.Addr {
-	return l.listener.Addr()
-}
-
-// Close closes the listener.
-func (l *Listener) Close() error {
-	return l.listener.Close()
-}
-
-// Accept waits for and returns the next incoming SSH connection.
-// The receiver should call Handshake() in another goroutine
-// to avoid blocking the accepter.
-func (l *Listener) Accept() (*ServerConn, error) {
-	c, err := l.listener.Accept()
-	if err != nil {
-		return nil, err
-	}
-	return Server(c, l.config), nil
-}
-
-// Listen creates an SSH listener accepting connections on
-// the given network address using net.Listen.
-func Listen(network, addr string, config *ServerConfig) (*Listener, error) {
-	l, err := net.Listen(network, addr)
-	if err != nil {
-		return nil, err
-	}
-	return &Listener{
-		l,
-		config,
-	}, nil
-}
diff --git a/ssh/session.go b/ssh/session.go
index 39f2d22..3b42b50 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -129,128 +129,126 @@
 	Stdout io.Writer
 	Stderr io.Writer
 
-	*clientChan // the channel backing this session
-
-	started   bool // true once Start, Run or Shell is invoked.
+	ch        Channel // the channel backing this session
+	started   bool    // true once Start, Run or Shell is invoked.
 	copyFuncs []func() error
 	errors    chan error // one send per copyFunc
 
 	// true if pipe method is active
 	stdinpipe, stdoutpipe, stderrpipe bool
+
+	// stdinPipeWriter is non-nil if StdinPipe has not been called
+	// and Stdin was specified by the user; it is the write end of
+	// a pipe connecting Session.Stdin to the stdin channel.
+	stdinPipeWriter io.WriteCloser
+
+	exitStatus chan error
+}
+
+// SendRequest sends an out-of-band channel request on the SSH channel
+// underlying the session.
+func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+	return s.ch.SendRequest(name, wantReply, payload)
+}
+
+func (s *Session) Close() error {
+	return s.ch.Close()
 }
 
 // RFC 4254 Section 6.4.
 type setenvRequest struct {
-	PeersId   uint32
-	Request   string
-	WantReply bool
-	Name      string
-	Value     string
-}
-
-// RFC 4254 Section 6.5.
-type subsystemRequestMsg struct {
-	PeersId   uint32
-	Request   string
-	WantReply bool
-	Subsystem string
+	Name  string
+	Value string
 }
 
 // Setenv sets an environment variable that will be applied to any
 // command executed by Shell or Run.
 func (s *Session) Setenv(name, value string) error {
-	req := setenvRequest{
-		PeersId:   s.remoteId,
-		Request:   "env",
-		WantReply: true,
-		Name:      name,
-		Value:     value,
+	msg := setenvRequest{
+		Name:  name,
+		Value: value,
 	}
-	if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
-		return err
+	ok, err := s.ch.SendRequest("env", true, Marshal(&msg))
+	if err == nil && !ok {
+		err = errors.New("ssh: setenv failed")
 	}
-	return s.waitForResponse()
+	return err
 }
 
 // RFC 4254 Section 6.2.
 type ptyRequestMsg struct {
-	PeersId   uint32
-	Request   string
-	WantReply bool
-	Term      string
-	Columns   uint32
-	Rows      uint32
-	Width     uint32
-	Height    uint32
-	Modelist  string
+	Term     string
+	Columns  uint32
+	Rows     uint32
+	Width    uint32
+	Height   uint32
+	Modelist string
 }
 
 // RequestPty requests the association of a pty with the session on the remote host.
 func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error {
 	var tm []byte
 	for k, v := range termmodes {
-		tm = append(tm, k)
-		tm = appendU32(tm, v)
+		kv := struct {
+			Key byte
+			Val uint32
+		}{k, v}
+
+		tm = append(tm, Marshal(&kv)...)
 	}
 	tm = append(tm, tty_OP_END)
 	req := ptyRequestMsg{
-		PeersId:   s.remoteId,
-		Request:   "pty-req",
-		WantReply: true,
-		Term:      term,
-		Columns:   uint32(w),
-		Rows:      uint32(h),
-		Width:     uint32(w * 8),
-		Height:    uint32(h * 8),
-		Modelist:  string(tm),
+		Term:     term,
+		Columns:  uint32(w),
+		Rows:     uint32(h),
+		Width:    uint32(w * 8),
+		Height:   uint32(h * 8),
+		Modelist: string(tm),
 	}
-	if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
-		return err
+	ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req))
+	if err == nil && !ok {
+		err = errors.New("ssh: pty-req failed")
 	}
-	return s.waitForResponse()
+	return err
+}
+
+// RFC 4254 Section 6.5.
+type subsystemRequestMsg struct {
+	Subsystem string
 }
 
 // RequestSubsystem requests the association of a subsystem with the session on the remote host.
 // A subsystem is a predefined command that runs in the background when the ssh session is initiated
 func (s *Session) RequestSubsystem(subsystem string) error {
-	req := subsystemRequestMsg{
-		PeersId:   s.remoteId,
-		Request:   "subsystem",
-		WantReply: true,
+	msg := subsystemRequestMsg{
 		Subsystem: subsystem,
 	}
-	if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
-		return err
+	ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg))
+	if err == nil && !ok {
+		err = errors.New("ssh: subsystem request failed")
 	}
-	return s.waitForResponse()
+	return err
 }
 
 // RFC 4254 Section 6.9.
 type signalMsg struct {
-	PeersId   uint32
-	Request   string
-	WantReply bool
-	Signal    string
+	Signal string
 }
 
 // Signal sends the given signal to the remote process.
 // sig is one of the SIG* constants.
 func (s *Session) Signal(sig Signal) error {
-	req := signalMsg{
-		PeersId:   s.remoteId,
-		Request:   "signal",
-		WantReply: false,
-		Signal:    string(sig),
+	msg := signalMsg{
+		Signal: string(sig),
 	}
-	return s.writePacket(marshal(msgChannelRequest, req))
+
+	_, err := s.ch.SendRequest("signal", false, Marshal(&msg))
+	return err
 }
 
 // RFC 4254 Section 6.5.
 type execMsg struct {
-	PeersId   uint32
-	Request   string
-	WantReply bool
-	Command   string
+	Command string
 }
 
 // Start runs cmd on the remote host. Typically, the remote
@@ -261,17 +259,16 @@
 		return errors.New("ssh: session already started")
 	}
 	req := execMsg{
-		PeersId:   s.remoteId,
-		Request:   "exec",
-		WantReply: true,
-		Command:   cmd,
+		Command: cmd,
 	}
-	if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+
+	ok, err := s.ch.SendRequest("exec", true, Marshal(&req))
+	if err == nil && !ok {
+		err = fmt.Errorf("ssh: command %v failed", cmd)
+	}
+	if err != nil {
 		return err
 	}
-	if err := s.waitForResponse(); err != nil {
-		return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err)
-	}
 	return s.start()
 }
 
@@ -339,31 +336,17 @@
 	if s.started {
 		return errors.New("ssh: session already started")
 	}
-	req := channelRequestMsg{
-		PeersId:   s.remoteId,
-		Request:   "shell",
-		WantReply: true,
+
+	ok, err := s.ch.SendRequest("shell", true, nil)
+	if err == nil && !ok {
+		return fmt.Errorf("ssh: cound not start shell")
 	}
-	if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+	if err != nil {
 		return err
 	}
-	if err := s.waitForResponse(); err != nil {
-		return fmt.Errorf("ssh: could not execute shell: %v", err)
-	}
 	return s.start()
 }
 
-func (s *Session) waitForResponse() error {
-	msg := <-s.msg
-	switch msg.(type) {
-	case *channelRequestSuccessMsg:
-		return nil
-	case *channelRequestFailureMsg:
-		return errors.New("ssh: request failed")
-	}
-	return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg)
-}
-
 func (s *Session) start() error {
 	s.started = true
 
@@ -394,8 +377,11 @@
 	if !s.started {
 		return errors.New("ssh: session not started")
 	}
-	waitErr := s.wait()
+	waitErr := <-s.exitStatus
 
+	if s.stdinPipeWriter != nil {
+		s.stdinPipeWriter.Close()
+	}
 	var copyError error
 	for _ = range s.copyFuncs {
 		if err := <-s.errors; err != nil && copyError == nil {
@@ -408,52 +394,35 @@
 	return copyError
 }
 
-func (s *Session) wait() error {
+func (s *Session) wait(reqs <-chan *Request) error {
 	wm := Waitmsg{status: -1}
-
 	// Wait for msg channel to be closed before returning.
-	for msg := range s.msg {
-		switch msg := msg.(type) {
-		case *channelRequestMsg:
-			switch msg.Request {
-			case "exit-status":
-				d := msg.RequestSpecificData
-				wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
-			case "exit-signal":
-				signal, rest, ok := parseString(msg.RequestSpecificData)
-				if !ok {
-					return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
-				}
-				wm.signal = safeString(string(signal))
-
-				// skip coreDumped bool
-				if len(rest) == 0 {
-					return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
-				}
-				rest = rest[1:]
-
-				errmsg, rest, ok := parseString(rest)
-				if !ok {
-					return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
-				}
-				wm.msg = safeString(string(errmsg))
-
-				lang, _, ok := parseString(rest)
-				if !ok {
-					return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
-				}
-				wm.lang = safeString(string(lang))
-			default:
-				// This handles keepalives and matches
-				// OpenSSH's behaviour.
-				if msg.WantReply {
-					s.writePacket(marshal(msgChannelFailure, channelRequestFailureMsg{
-						PeersId: s.remoteId,
-					}))
-				}
+	for msg := range reqs {
+		switch msg.Type {
+		case "exit-status":
+			d := msg.Payload
+			wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
+		case "exit-signal":
+			var sigval struct {
+				Signal     string
+				CoreDumped bool
+				Error      string
+				Lang       string
 			}
+			if err := Unmarshal(msg.Payload, &sigval); err != nil {
+				return err
+			}
+
+			// Must sanitize strings?
+			wm.signal = sigval.Signal
+			wm.msg = sigval.Error
+			wm.lang = sigval.Lang
 		default:
-			return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg)
+			// This handles keepalives and matches
+			// OpenSSH's behaviour.
+			if msg.WantReply {
+				msg.Reply(false, nil)
+			}
 		}
 	}
 	if wm.status == 0 {
@@ -476,12 +445,20 @@
 	if s.stdinpipe {
 		return
 	}
+	var stdin io.Reader
 	if s.Stdin == nil {
-		s.Stdin = new(bytes.Buffer)
+		stdin = new(bytes.Buffer)
+	} else {
+		r, w := io.Pipe()
+		go func() {
+			_, err := io.Copy(w, s.Stdin)
+			w.CloseWithError(err)
+		}()
+		stdin, s.stdinPipeWriter = r, w
 	}
 	s.copyFuncs = append(s.copyFuncs, func() error {
-		_, err := io.Copy(s.clientChan.stdin, s.Stdin)
-		if err1 := s.clientChan.stdin.Close(); err == nil && err1 != io.EOF {
+		_, err := io.Copy(s.ch, stdin)
+		if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
 			err = err1
 		}
 		return err
@@ -496,7 +473,7 @@
 		s.Stdout = ioutil.Discard
 	}
 	s.copyFuncs = append(s.copyFuncs, func() error {
-		_, err := io.Copy(s.Stdout, s.clientChan.stdout)
+		_, err := io.Copy(s.Stdout, s.ch)
 		return err
 	})
 }
@@ -509,11 +486,21 @@
 		s.Stderr = ioutil.Discard
 	}
 	s.copyFuncs = append(s.copyFuncs, func() error {
-		_, err := io.Copy(s.Stderr, s.clientChan.stderr)
+		_, err := io.Copy(s.Stderr, s.ch.Stderr())
 		return err
 	})
 }
 
+// sessionStdin reroutes Close to CloseWrite.
+type sessionStdin struct {
+	io.Writer
+	ch Channel
+}
+
+func (s *sessionStdin) Close() error {
+	return s.ch.CloseWrite()
+}
+
 // StdinPipe returns a pipe that will be connected to the
 // remote command's standard input when the command starts.
 func (s *Session) StdinPipe() (io.WriteCloser, error) {
@@ -524,7 +511,7 @@
 		return nil, errors.New("ssh: StdinPipe after process started")
 	}
 	s.stdinpipe = true
-	return s.clientChan.stdin, nil
+	return &sessionStdin{s.ch, s.ch}, nil
 }
 
 // StdoutPipe returns a pipe that will be connected to the
@@ -541,7 +528,7 @@
 		return nil, errors.New("ssh: StdoutPipe after process started")
 	}
 	s.stdoutpipe = true
-	return s.clientChan.stdout, nil
+	return s.ch, nil
 }
 
 // StderrPipe returns a pipe that will be connected to the
@@ -558,28 +545,20 @@
 		return nil, errors.New("ssh: StderrPipe after process started")
 	}
 	s.stderrpipe = true
-	return s.clientChan.stderr, nil
+	return s.ch.Stderr(), nil
 }
 
-// NewSession returns a new interactive session on the remote host.
-func (c *ClientConn) NewSession() (*Session, error) {
-	ch := c.newChan(c.transport)
-	if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{
-		ChanType:      "session",
-		PeersId:       ch.localId,
-		PeersWindow:   channelWindowSize,
-		MaxPacketSize: channelMaxPacketSize,
-	})); err != nil {
-		c.chanList.remove(ch.localId)
-		return nil, err
+// newSession returns a new interactive session on the remote host.
+func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
+	s := &Session{
+		ch: ch,
 	}
-	if err := ch.waitForChannelOpenResponse(); err != nil {
-		c.chanList.remove(ch.localId)
-		return nil, fmt.Errorf("ssh: unable to open session: %v", err)
-	}
-	return &Session{
-		clientChan: ch,
-	}, nil
+	s.exitStatus = make(chan error, 1)
+	go func() {
+		s.exitStatus <- s.wait(reqs)
+	}()
+
+	return s, nil
 }
 
 // An ExitError reports unsuccessful completion of a remote command.
diff --git a/ssh/session_test.go b/ssh/session_test.go
index 5cff58a..cc26573 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -12,71 +12,60 @@
 	"io"
 	"io/ioutil"
 	"math/rand"
-	"net"
 	"testing"
 
 	"code.google.com/p/go.crypto/ssh/terminal"
 )
 
-type serverType func(*serverChan, *testing.T)
+type serverType func(Channel, <-chan *Request, *testing.T)
 
 // dial constructs a new test server and returns a *ClientConn.
-func dial(handler serverType, t *testing.T) *ClientConn {
-	l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
+func dial(handler serverType, t *testing.T) *Client {
+	c1, c2, err := netPipe()
 	if err != nil {
-		t.Fatalf("unable to listen: %v", err)
+		t.Fatalf("netPipe: %v", err)
 	}
+
 	go func() {
-		defer l.Close()
-		conn, err := l.Accept()
+		defer c1.Close()
+		conf := ServerConfig{
+			NoClientAuth: true,
+		}
+		conf.AddHostKey(testSigners["rsa"])
+
+		_, chans, reqs, err := NewServerConn(c1, &conf)
 		if err != nil {
-			t.Errorf("Unable to accept: %v", err)
-			return
+			t.Fatalf("Unable to handshake: %v", err)
 		}
-		defer conn.Close()
-		if err := conn.Handshake(); err != nil {
-			t.Errorf("Unable to handshake: %v", err)
-			return
-		}
-		done := make(chan struct{})
-		for {
-			ch, err := conn.Accept()
-			if err == io.EOF || err == io.ErrUnexpectedEOF {
-				return
-			}
-			// We sometimes get ECONNRESET rather than EOF.
-			if _, ok := err.(*net.OpError); ok {
-				return
-			}
-			if err != nil {
-				t.Errorf("Unable to accept incoming channel request: %v", err)
-				return
-			}
-			if ch.ChannelType() != "session" {
-				ch.Reject(UnknownChannelType, "unknown channel type")
+		go DiscardRequests(reqs)
+
+		for newCh := range chans {
+			if newCh.ChannelType() != "session" {
+				newCh.Reject(UnknownChannelType, "unknown channel type")
 				continue
 			}
-			ch.Accept()
+
+			ch, inReqs, err := newCh.Accept()
+			if err != nil {
+				t.Errorf("Accept: %v", err)
+				continue
+			}
 			go func() {
-				defer close(done)
-				handler(ch.(*serverChan), t)
+				handler(ch, inReqs, t)
 			}()
 		}
-		<-done
 	}()
 
 	config := &ClientConfig{
 		User: "testuser",
-		Auth: []ClientAuth{
-			ClientAuthPassword(clientPassword),
-		},
 	}
 
-	c, err := Dial("tcp", l.Addr().String(), config)
+	conn, chans, reqs, err := NewClientConn(c2, "", config)
 	if err != nil {
 		t.Fatalf("unable to dial remote side: %v", err)
 	}
-	return c
+
+	return NewClient(conn, chans, reqs)
 }
 
 // Test a simple string is returned to session.Stdout.
@@ -330,164 +319,6 @@
 	}
 }
 
-func TestInvalidServerMessage(t *testing.T) {
-	conn := dial(sendInvalidRecord, t)
-	defer conn.Close()
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatalf("Unable to request new session: %v", err)
-	}
-	// Make sure that we closed all the clientChans when the connection
-	// failed.
-	session.wait()
-
-	defer session.Close()
-}
-
-// In the wild some clients (and servers) send zero sized window updates.
-// Test that the client can continue after receiving a zero sized update.
-func TestClientZeroWindowAdjust(t *testing.T) {
-	conn := dial(sendZeroWindowAdjust, t)
-	defer conn.Close()
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatalf("Unable to request new session: %v", err)
-	}
-	defer session.Close()
-
-	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %v", err)
-	}
-	err = session.Wait()
-	if err != nil {
-		t.Fatalf("expected nil but got %v", err)
-	}
-}
-
-// In the wild some clients (and servers) send zero sized window updates.
-// Test that the server can continue after receiving a zero size update.
-func TestServerZeroWindowAdjust(t *testing.T) {
-	conn := dial(exitStatusZeroHandler, t)
-	defer conn.Close()
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatalf("Unable to request new session: %v", err)
-	}
-	defer session.Close()
-
-	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %v", err)
-	}
-
-	// send a bogus zero sized window update
-	session.clientChan.sendWindowAdj(0)
-
-	err = session.Wait()
-	if err != nil {
-		t.Fatalf("expected nil but got %v", err)
-	}
-}
-
-// Verify that the client never sends a packet larger than maxpacket.
-func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
-	conn := dial(discardHandler, t)
-	defer conn.Close()
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatalf("failed to request new session: %v", err)
-	}
-	defer session.Close()
-	stdin, err := session.StdinPipe()
-	if err != nil {
-		t.Fatalf("failed to obtain stdinpipe: %v", err)
-	}
-	const size = 100 * 1000
-	for i := 0; i < 10; i++ {
-		n, err := stdin.Write(make([]byte, size))
-		if n != size || err != nil {
-			t.Fatalf("failed to write: %d, %v", n, err)
-		}
-	}
-}
-
-// Verify that the client never accepts a packet larger than maxpacket.
-func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
-	conn := dial(largeSendHandler, t)
-	defer conn.Close()
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatalf("Unable to request new session: %v", err)
-	}
-	defer session.Close()
-	out, err := session.StdoutPipe()
-	if err != nil {
-		t.Fatalf("Unable to connect to Stdout: %v", err)
-	}
-	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %v", err)
-	}
-	if _, err := ioutil.ReadAll(out); err != nil {
-		t.Fatalf("failed to read: %v", err)
-	}
-}
-
-func TestClientCannotSendAfterEOF(t *testing.T) {
-	conn := dial(exitWithoutSignalOrStatus, t)
-	defer conn.Close()
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatalf("Unable to request new session: %v", err)
-	}
-	defer session.Close()
-	in, err := session.StdinPipe()
-	if err != nil {
-		t.Fatalf("Unable to connect channel stdin: %v", err)
-	}
-	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %v", err)
-	}
-	if err := in.Close(); err != nil {
-		t.Fatalf("Unable to close stdin: %v", err)
-	}
-	if _, err := in.Write([]byte("foo")); err == nil {
-		t.Fatalf("Session write should fail")
-	}
-}
-
-func TestClientCannotSendAfterClose(t *testing.T) {
-	conn := dial(exitWithoutSignalOrStatus, t)
-	defer conn.Close()
-	session, err := conn.NewSession()
-	if err != nil {
-		t.Fatalf("Unable to request new session: %v", err)
-	}
-	defer session.Close()
-	in, err := session.StdinPipe()
-	if err != nil {
-		t.Fatalf("Unable to connect channel stdin: %v", err)
-	}
-	if err := session.Shell(); err != nil {
-		t.Fatalf("Unable to execute command: %v", err)
-	}
-	// close underlying channel
-	if err := session.channel.Close(); err != nil {
-		t.Fatalf("Unable to close session: %v", err)
-	}
-	if _, err := in.Write([]byte("foo")); err == nil {
-		t.Fatalf("Session write should fail")
-	}
-}
-
-func TestClientCannotSendHugePacket(t *testing.T) {
-	// client and server use the same transport write code so this
-	// test suffices for both.
-	conn := dial(shellHandler, t)
-	defer conn.Close()
-	if err := conn.transport.writePacket(make([]byte, maxPacket*2)); err == nil {
-		t.Fatalf("huge packet write should fail")
-	}
-}
-
 // windowTestBytes is the number of bytes that we'll send to the SSH server.
 const windowTestBytes = 16000 * 200
 
@@ -560,93 +391,104 @@
 }
 
 type exitStatusMsg struct {
-	PeersId   uint32
-	Request   string
-	WantReply bool
-	Status    uint32
+	Status uint32
 }
 
 type exitSignalMsg struct {
-	PeersId    uint32
-	Request    string
-	WantReply  bool
 	Signal     string
 	CoreDumped bool
 	Errmsg     string
 	Lang       string
 }
 
-func newServerShell(ch *serverChan, prompt string) *ServerTerminal {
-	term := terminal.NewTerminal(ch, prompt)
-	return &ServerTerminal{
-		Term:    term,
-		Channel: ch,
+func handleTerminalRequests(in <-chan *Request) {
+	for req := range in {
+		ok := false
+		switch req.Type {
+		case "shell":
+			ok = true
+			if len(req.Payload) > 0 {
+				// We don't accept any commands, only the default shell.
+				ok = false
+			}
+		case "env":
+			ok = true
+		}
+		req.Reply(ok, nil)
 	}
 }
 
-func exitStatusZeroHandler(ch *serverChan, t *testing.T) {
+func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
+	term := terminal.NewTerminal(ch, prompt)
+	go handleTerminalRequests(in)
+	return term
+}
+
+func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
 	// this string is returned to stdout
-	shell := newServerShell(ch, "> ")
+	shell := newServerShell(ch, in, "> ")
 	readLine(shell, t)
 	sendStatus(0, ch, t)
 }
 
-func exitStatusNonZeroHandler(ch *serverChan, t *testing.T) {
+func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
-	shell := newServerShell(ch, "> ")
+	shell := newServerShell(ch, in, "> ")
 	readLine(shell, t)
 	sendStatus(15, ch, t)
 }
 
-func exitSignalAndStatusHandler(ch *serverChan, t *testing.T) {
+func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
-	shell := newServerShell(ch, "> ")
+	shell := newServerShell(ch, in, "> ")
 	readLine(shell, t)
 	sendStatus(15, ch, t)
 	sendSignal("TERM", ch, t)
 }
 
-func exitSignalHandler(ch *serverChan, t *testing.T) {
+func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
-	shell := newServerShell(ch, "> ")
+	shell := newServerShell(ch, in, "> ")
 	readLine(shell, t)
 	sendSignal("TERM", ch, t)
 }
 
-func exitSignalUnknownHandler(ch *serverChan, t *testing.T) {
+func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
-	shell := newServerShell(ch, "> ")
+	shell := newServerShell(ch, in, "> ")
 	readLine(shell, t)
 	sendSignal("SYS", ch, t)
 }
 
-func exitWithoutSignalOrStatus(ch *serverChan, t *testing.T) {
+func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
-	shell := newServerShell(ch, "> ")
+	shell := newServerShell(ch, in, "> ")
 	readLine(shell, t)
 }
 
-func shellHandler(ch *serverChan, t *testing.T) {
+func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
 	// this string is returned to stdout
-	shell := newServerShell(ch, "golang")
+	shell := newServerShell(ch, in, "golang")
 	readLine(shell, t)
 	sendStatus(0, ch, t)
 }
 
 // Ignores the command, writes fixed strings to stderr and stdout.
 // Strings are "this-is-stdout." and "this-is-stderr.".
-func fixedOutputHandler(ch *serverChan, t *testing.T) {
+func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
+	_, err := ch.Read(nil)
 
-	_, err := ch.Read(make([]byte, 0))
-	if _, ok := err.(ChannelRequest); !ok {
+	req, ok := <-in
+	if !ok {
 		t.Fatalf("error: expected channel request, got: %#v", err)
 		return
 	}
+
 	// ignore request, always send some text
-	ch.AckRequest(true)
+	req.Reply(true, nil)
 
 	_, err = io.WriteString(ch, "this-is-stdout.")
 	if err != nil {
@@ -659,84 +501,39 @@
 	sendStatus(0, ch, t)
 }
 
-func readLine(shell *ServerTerminal, t *testing.T) {
+func readLine(shell *terminal.Terminal, t *testing.T) {
 	if _, err := shell.ReadLine(); err != nil && err != io.EOF {
 		t.Errorf("unable to read line: %v", err)
 	}
 }
 
-func sendStatus(status uint32, ch *serverChan, t *testing.T) {
+func sendStatus(status uint32, ch Channel, t *testing.T) {
 	msg := exitStatusMsg{
-		PeersId:   ch.remoteId,
-		Request:   "exit-status",
-		WantReply: false,
-		Status:    status,
+		Status: status,
 	}
-	if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil {
+	if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
 		t.Errorf("unable to send status: %v", err)
 	}
 }
 
-func sendSignal(signal string, ch *serverChan, t *testing.T) {
+func sendSignal(signal string, ch Channel, t *testing.T) {
 	sig := exitSignalMsg{
-		PeersId:    ch.remoteId,
-		Request:    "exit-signal",
-		WantReply:  false,
 		Signal:     signal,
 		CoreDumped: false,
 		Errmsg:     "Process terminated",
 		Lang:       "en-GB-oed",
 	}
-	if err := ch.writePacket(marshal(msgChannelRequest, sig)); err != nil {
+	if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
 		t.Errorf("unable to send signal: %v", err)
 	}
 }
 
-func sendInvalidRecord(ch *serverChan, t *testing.T) {
+func discardHandler(ch Channel, t *testing.T) {
 	defer ch.Close()
-	packet := make([]byte, 1+4+4+1)
-	packet[0] = msgChannelData
-	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
-	marshalUint32(packet[5:], 1)
-	packet[9] = 42
-
-	if err := ch.writePacket(packet); err != nil {
-		t.Errorf("unable send invalid record: %v", err)
-	}
-}
-
-func sendZeroWindowAdjust(ch *serverChan, t *testing.T) {
-	defer ch.Close()
-	// send a bogus zero sized window update
-	ch.sendWindowAdj(0)
-	shell := newServerShell(ch, "> ")
-	readLine(shell, t)
-	sendStatus(0, ch, t)
-}
-
-func discardHandler(ch *serverChan, t *testing.T) {
-	defer ch.Close()
-	// grow the window to avoid being fooled by
-	// the initial 1 << 14 window.
-	ch.sendWindowAdj(1024 * 1024)
 	io.Copy(ioutil.Discard, ch)
 }
 
-func largeSendHandler(ch *serverChan, t *testing.T) {
-	defer ch.Close()
-	// grow the window to avoid being fooled by
-	// the initial 1 << 14 window.
-	ch.sendWindowAdj(1024 * 1024)
-	shell := newServerShell(ch, "> ")
-	readLine(shell, t)
-	// try to send more than the 32k window
-	// will allow
-	if err := ch.writePacket(make([]byte, 128*1024)); err == nil {
-		t.Errorf("wrote packet larger than 32k")
-	}
-}
-
-func echoHandler(ch *serverChan, t *testing.T) {
+func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
 	if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
 		t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
@@ -773,17 +570,59 @@
 	return written, nil
 }
 
-func channelKeepaliveSender(ch *serverChan, t *testing.T) {
+func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
 	defer ch.Close()
-	shell := newServerShell(ch, "> ")
+	shell := newServerShell(ch, in, "> ")
 	readLine(shell, t)
-	msg := channelRequestMsg{
-		PeersId:   ch.remoteId,
-		Request:   "keepalive@openssh.com",
-		WantReply: true,
-	}
-	if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil {
+	if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
 		t.Errorf("unable to send channel keepalive request: %v", err)
 	}
 	sendStatus(0, ch, t)
 }
+
+func TestClientWriteEOF(t *testing.T) {
+	conn := dial(simpleEchoHandler, t)
+	defer conn.Close()
+
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer session.Close()
+	stdin, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("StdinPipe failed: %v", err)
+	}
+	stdout, err := session.StdoutPipe()
+	if err != nil {
+		t.Fatalf("StdoutPipe failed: %v", err)
+	}
+
+	data := []byte(`0000`)
+	_, err = stdin.Write(data)
+	if err != nil {
+		t.Fatalf("Write failed: %v", err)
+	}
+	stdin.Close()
+
+	res, err := ioutil.ReadAll(stdout)
+	if err != nil {
+		t.Fatalf("Read failed: %v", err)
+	}
+
+	if !bytes.Equal(data, res) {
+		t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
+	}
+}
+
+func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
+	defer ch.Close()
+	data, err := ioutil.ReadAll(ch)
+	if err != nil {
+		t.Errorf("handler read error: %v", err)
+	}
+	_, err = ch.Write(data)
+	if err != nil {
+		t.Errorf("handler write error: %v", err)
+	}
+}
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 74fc1a7..5a4fa8b 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -16,10 +16,11 @@
 	"time"
 )
 
-// Listen requests the remote peer open a listening socket
-// on addr. Incoming connections will be available by calling
-// Accept on the returned net.Listener.
-func (c *ClientConn) Listen(n, addr string) (net.Listener, error) {
+// Listen requests the remote peer open a listening socket on
+// addr. Incoming connections will be available by calling Accept on
+// the returned net.Listener. The listener must be serviced, or the
+// SSH connection may hang.
+func (c *Client) Listen(n, addr string) (net.Listener, error) {
 	laddr, err := net.ResolveTCPAddr(n, addr)
 	if err != nil {
 		return nil, err
@@ -59,7 +60,7 @@
 
 // autoPortListenWorkaround simulates automatic port allocation by
 // trying random ports repeatedly.
-func (c *ClientConn) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
+func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
 	var sshListener net.Listener
 	var err error
 	const tries = 10
@@ -77,44 +78,45 @@
 
 // RFC 4254 7.1
 type channelForwardMsg struct {
-	Message   string
-	WantReply bool
-	raddr     string
-	rport     uint32
+	addr  string
+	rport uint32
 }
 
 // ListenTCP requests the remote peer open a listening socket
 // on laddr. Incoming connections will be available by calling
 // Accept on the returned net.Listener.
-func (c *ClientConn) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
-	if laddr.Port == 0 && isBrokenOpenSSHVersion(c.serverVersion) {
+func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
+	if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
 		return c.autoPortListenWorkaround(laddr)
 	}
 
 	m := channelForwardMsg{
-		"tcpip-forward",
-		true, // sendGlobalRequest waits for a reply
 		laddr.IP.String(),
 		uint32(laddr.Port),
 	}
 	// send message
-	resp, err := c.sendGlobalRequest(m)
+	ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
 	if err != nil {
 		return nil, err
 	}
+	if !ok {
+		return nil, errors.New("ssh: tcpip-forward request denied by peer")
+	}
 
 	// If the original port was 0, then the remote side will
 	// supply a real port number in the response.
 	if laddr.Port == 0 {
-		port, _, ok := parseUint32(resp.Data)
-		if !ok {
-			return nil, errors.New("unable to parse response")
+		var p struct {
+			Port uint32
 		}
-		laddr.Port = int(port)
+		if err := Unmarshal(resp, &p); err != nil {
+			return nil, err
+		}
+		laddr.Port = int(p.Port)
 	}
 
 	// Register this forward, using the port number we obtained.
-	ch := c.forwardList.add(*laddr)
+	ch := c.forwards.add(*laddr)
 
 	return &tcpListener{laddr, c, ch}, nil
 }
@@ -137,7 +139,7 @@
 // arguments to add/remove/lookup should be address as specified in
 // the original forward-request.
 type forward struct {
-	c     *clientChan  // the ssh client channel underlying this forward
+	newCh NewChannel   // the ssh client channel underlying this forward
 	raddr *net.TCPAddr // the raddr of the incoming connection
 }
 
@@ -152,6 +154,31 @@
 	return f.c
 }
 
+func (l *forwardList) handleChannels(in <-chan NewChannel) {
+	for ch := range in {
+		laddr, rest, ok := parseTCPAddr(ch.ExtraData())
+		if !ok {
+			// invalid request
+			ch.Reject(ConnectionFailed, "could not parse TCP address")
+			continue
+		}
+
+		raddr, rest, ok := parseTCPAddr(rest)
+		if !ok {
+			// invalid request
+			ch.Reject(ConnectionFailed, "could not parse TCP address")
+			continue
+		}
+
+		if ok := l.forward(*laddr, *raddr, ch); !ok {
+			// Section 7.2, implementations MUST reject spurious incoming
+			// connections.
+			ch.Reject(Prohibited, "no forward for address")
+			continue
+		}
+	}
+}
+
 // remove removes the forward entry, and the channel feeding its
 // listener.
 func (l *forwardList) remove(addr net.TCPAddr) {
@@ -176,21 +203,22 @@
 	l.entries = nil
 }
 
-func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) {
+func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
 	l.Lock()
 	defer l.Unlock()
 	for _, f := range l.entries {
-		if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
-			return f.c, true
+		if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
+			f.c <- forward{ch, &raddr}
+			return true
 		}
 	}
-	return nil, false
+	return false
 }
 
 type tcpListener struct {
 	laddr *net.TCPAddr
 
-	conn *ClientConn
+	conn *Client
 	in   <-chan forward
 }
 
@@ -200,30 +228,33 @@
 	if !ok {
 		return nil, io.EOF
 	}
+	ch, incoming, err := s.newCh.Accept()
+	if err != nil {
+		return nil, err
+	}
+	go DiscardRequests(incoming)
+
 	return &tcpChanConn{
-		tcpChan: &tcpChan{
-			clientChan: s.c,
-			Reader:     s.c.stdout,
-			Writer:     s.c.stdin,
-		},
-		laddr: l.laddr,
-		raddr: s.raddr,
+		Channel: ch,
+		laddr:   l.laddr,
+		raddr:   s.raddr,
 	}, nil
 }
 
 // Close closes the listener.
 func (l *tcpListener) Close() error {
 	m := channelForwardMsg{
-		"cancel-tcpip-forward",
-		true,
 		l.laddr.IP.String(),
 		uint32(l.laddr.Port),
 	}
-	l.conn.forwardList.remove(*l.laddr)
-	if _, err := l.conn.sendGlobalRequest(m); err != nil {
-		return err
+
+	// this also closes the listener.
+	l.conn.forwards.remove(*l.laddr)
+	ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
+	if err == nil && !ok {
+		err = errors.New("ssh: cancel-tcpip-forward failed")
 	}
-	return nil
+	return err
 }
 
 // Addr returns the listener's network address.
@@ -233,7 +264,7 @@
 
 // Dial initiates a connection to the addr from the remote host.
 // The resulting connection has a zero LocalAddr() and RemoteAddr().
-func (c *ClientConn) Dial(n, addr string) (net.Conn, error) {
+func (c *Client) Dial(n, addr string) (net.Conn, error) {
 	// Parse the address into host and numeric port.
 	host, portString, err := net.SplitHostPort(addr)
 	if err != nil {
@@ -253,7 +284,7 @@
 		return nil, err
 	}
 	return &tcpChanConn{
-		tcpChan: ch,
+		Channel: ch,
 		laddr:   zeroAddr,
 		raddr:   zeroAddr,
 	}, nil
@@ -262,7 +293,7 @@
 // DialTCP connects to the remote address raddr on the network net,
 // which must be "tcp", "tcp4", or "tcp6".  If laddr is not nil, it is used
 // as the local address for the connection.
-func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
+func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
 	if laddr == nil {
 		laddr = &net.TCPAddr{
 			IP:   net.IPv4zero,
@@ -274,7 +305,7 @@
 		return nil, err
 	}
 	return &tcpChanConn{
-		tcpChan: ch,
+		Channel: ch,
 		laddr:   laddr,
 		raddr:   raddr,
 	}, nil
@@ -282,54 +313,32 @@
 
 // RFC 4254 7.2
 type channelOpenDirectMsg struct {
-	ChanType      string
-	PeersId       uint32
-	PeersWindow   uint32
-	MaxPacketSize uint32
-	raddr         string
-	rport         uint32
-	laddr         string
-	lport         uint32
+	raddr string
+	rport uint32
+	laddr string
+	lport uint32
 }
 
-// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
-// strings and are expected to be resolvable at the remote end.
-func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
-	ch := c.newChan(c.transport)
-	if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
-		ChanType:      "direct-tcpip",
-		PeersId:       ch.localId,
-		PeersWindow:   channelWindowSize,
-		MaxPacketSize: channelMaxPacketSize,
-		raddr:         raddr,
-		rport:         uint32(rport),
-		laddr:         laddr,
-		lport:         uint32(lport),
-	})); err != nil {
-		c.chanList.remove(ch.localId)
-		return nil, err
+func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
+	msg := channelOpenDirectMsg{
+		raddr: raddr,
+		rport: uint32(rport),
+		laddr: laddr,
+		lport: uint32(lport),
 	}
-	if err := ch.waitForChannelOpenResponse(); err != nil {
-		c.chanList.remove(ch.localId)
-		return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err)
-	}
-	return &tcpChan{
-		clientChan: ch,
-		Reader:     ch.stdout,
-		Writer:     ch.stdin,
-	}, nil
+	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
+	go DiscardRequests(in)
+	return ch, err
 }
 
 type tcpChan struct {
-	*clientChan // the backing channel
-	io.Reader
-	io.Writer
+	Channel // the backing channel
 }
 
 // tcpChanConn fulfills the net.Conn interface without
 // the tcpChan having to hold laddr or raddr directly.
 type tcpChanConn struct {
-	*tcpChan
+	Channel
 	laddr, raddr net.Addr
 }
 
diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go
index 7fa9fc4..f1265cb 100644
--- a/ssh/tcpip_test.go
+++ b/ssh/tcpip_test.go
@@ -1,3 +1,7 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
 package ssh
 
 import (
diff --git a/ssh/terminal/util_windows.go b/ssh/terminal/util_windows.go
new file mode 100644
index 0000000..0a454e0
--- /dev/null
+++ b/ssh/terminal/util_windows.go
@@ -0,0 +1,171 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build windows
+
+// Package terminal provides support functions for dealing with terminals, as
+// commonly found on UNIX systems.
+//
+// Putting a terminal into raw mode is the most common requirement:
+//
+// 	oldState, err := terminal.MakeRaw(0)
+// 	if err != nil {
+// 	        panic(err)
+// 	}
+// 	defer terminal.Restore(0, oldState)
+package terminal
+
+import (
+	"io"
+	"syscall"
+	"unsafe"
+)
+
+const (
+	enableLineInput       = 2
+	enableEchoInput       = 4
+	enableProcessedInput  = 1
+	enableWindowInput     = 8
+	enableMouseInput      = 16
+	enableInsertMode      = 32
+	enableQuickEditMode   = 64
+	enableExtendedFlags   = 128
+	enableAutoPosition    = 256
+	enableProcessedOutput = 1
+	enableWrapAtEolOutput = 2
+)
+
+var kernel32 = syscall.NewLazyDLL("kernel32.dll")
+
+var (
+	procGetConsoleMode             = kernel32.NewProc("GetConsoleMode")
+	procSetConsoleMode             = kernel32.NewProc("SetConsoleMode")
+	procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
+)
+
+type (
+	short int16
+	word  uint16
+
+	coord struct {
+		x short
+		y short
+	}
+	smallRect struct {
+		left   short
+		top    short
+		right  short
+		bottom short
+	}
+	consoleScreenBufferInfo struct {
+		size              coord
+		cursorPosition    coord
+		attributes        word
+		window            smallRect
+		maximumWindowSize coord
+	}
+)
+
+type State struct {
+	mode uint32
+}
+
+// IsTerminal returns true if the given file descriptor is a terminal.
+func IsTerminal(fd int) bool {
+	var st uint32
+	r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
+	return r != 0 && e == 0
+}
+
+// MakeRaw put the terminal connected to the given file descriptor into raw
+// mode and returns the previous state of the terminal so that it can be
+// restored.
+func MakeRaw(fd int) (*State, error) {
+	var st uint32
+	_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+	st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
+	_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+	return &State{st}, nil
+}
+
+// GetState returns the current state of a terminal which may be useful to
+// restore the terminal after a signal.
+func GetState(fd int) (*State, error) {
+	var st uint32
+	_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+	return &State{st}, nil
+}
+
+// Restore restores the terminal connected to the given file descriptor to a
+// previous state.
+func Restore(fd int, state *State) error {
+	_, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0)
+	return err
+}
+
+// GetSize returns the dimensions of the given terminal.
+func GetSize(fd int) (width, height int, err error) {
+	var info consoleScreenBufferInfo
+	_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0)
+	if e != 0 {
+		return 0, 0, error(e)
+	}
+	return int(info.size.x), int(info.size.y), nil
+}
+
+// ReadPassword reads a line of input from a terminal without local echo.  This
+// is commonly used for inputting passwords and other sensitive data. The slice
+// returned does not include the \n.
+func ReadPassword(fd int) ([]byte, error) {
+	var st uint32
+	_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+	old := st
+
+	st &^= (enableEchoInput)
+	st |= (enableProcessedInput | enableLineInput | enableProcessedOutput)
+	_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+
+	defer func() {
+		syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
+	}()
+
+	var buf [16]byte
+	var ret []byte
+	for {
+		n, err := syscall.Read(syscall.Handle(fd), buf[:])
+		if err != nil {
+			return nil, err
+		}
+		if n == 0 {
+			if len(ret) == 0 {
+				return nil, io.EOF
+			}
+			break
+		}
+		if buf[n-1] == '\n' {
+			n--
+		}
+		ret = append(ret, buf[:n]...)
+		if n < len(buf) {
+			break
+		}
+	}
+
+	return ret, nil
+}
diff --git a/ssh/test/agent_unix_test.go b/ssh/test/agent_unix_test.go
new file mode 100644
index 0000000..26c88eb
--- /dev/null
+++ b/ssh/test/agent_unix_test.go
@@ -0,0 +1,50 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+package test
+
+import (
+	"bytes"
+	"testing"
+
+	"code.google.com/p/go.crypto/ssh"
+	"code.google.com/p/go.crypto/ssh/agent"
+)
+
+func TestAgentForward(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	keyring := agent.NewKeyring()
+	keyring.Add(testPrivateKeys["dsa"], nil, "")
+	pub := testPublicKeys["dsa"]
+
+	sess, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("NewSession: %v", err)
+	}
+	if err := agent.RequestAgentForwarding(sess); err != nil {
+		t.Fatalf("RequestAgentForwarding: %v", err)
+	}
+
+	if err := agent.ForwardToAgent(conn, keyring); err != nil {
+		t.Fatalf("SetupForwardKeyring: %v", err)
+	}
+	out, err := sess.CombinedOutput("ssh-add -L")
+	if err != nil {
+		t.Fatalf("running ssh-add: %v, out %s", err, out)
+	}
+	key, _, _, _, err := ssh.ParseAuthorizedKey(out)
+	if err != nil {
+		t.Fatalf("ParseAuthorizedKey(%q): %v", out, err)
+	}
+
+	if !bytes.Equal(key.Marshal(), pub.Marshal()) {
+		t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub))
+	}
+}
diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go
new file mode 100644
index 0000000..d4f7226
--- /dev/null
+++ b/ssh/test/cert_test.go
@@ -0,0 +1,47 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+package test
+
+import (
+	"crypto/rand"
+	"testing"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+func TestCertLogin(t *testing.T) {
+	s := newServer(t)
+	defer s.Shutdown()
+
+	// Use a key different from the default.
+	clientKey := testSigners["dsa"]
+	caAuthKey := testSigners["ecdsa"]
+	cert := &ssh.Certificate{
+		Key:             clientKey.PublicKey(),
+		ValidPrincipals: []string{username()},
+		CertType:        ssh.UserCert,
+		ValidBefore:     ssh.CertTimeInfinity,
+	}
+	if err := cert.SignCert(rand.Reader, caAuthKey); err != nil {
+		t.Fatalf("SetSignature: %v", err)
+	}
+
+	certSigner, err := ssh.NewCertSigner(cert, clientKey)
+	if err != nil {
+		t.Fatalf("NewCertSigner: %v", err)
+	}
+
+	conf := &ssh.ClientConfig{
+		User: username(),
+	}
+	conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner))
+	client, err := s.TryDial(conf)
+	if err != nil {
+		t.Fatalf("TryDial: %v", err)
+	}
+	client.Close()
+}
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go
index 3a57c10..881a9da 100644
--- a/ssh/test/forward_unix_test.go
+++ b/ssh/test/forward_unix_test.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// +build darwin freebsd linux netbsd openbsd plan9
+// +build darwin freebsd linux netbsd openbsd
 
 package test
 
diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go
index bd7307d..d8d35a5 100644
--- a/ssh/test/session_test.go
+++ b/ssh/test/session_test.go
@@ -11,6 +11,7 @@
 import (
 	"bytes"
 	"code.google.com/p/go.crypto/ssh"
+	"errors"
 	"io"
 	"strings"
 	"testing"
@@ -38,12 +39,13 @@
 	defer server.Shutdown()
 
 	conf := clientConfig()
-	k := conf.HostKeyChecker.(*storedHostKey)
+	hostDB := hostKeyDB()
+	conf.HostKeyCallback = hostDB.Check
 
 	// change the keys.
-	k.keys[ssh.KeyAlgoRSA][25]++
-	k.keys[ssh.KeyAlgoDSA][25]++
-	k.keys[ssh.KeyAlgoECDSA256][25]++
+	hostDB.keys[ssh.KeyAlgoRSA][25]++
+	hostDB.keys[ssh.KeyAlgoDSA][25]++
+	hostDB.keys[ssh.KeyAlgoECDSA256][25]++
 
 	conn, err := server.TryDial(conf)
 	if err == nil {
@@ -54,6 +56,53 @@
 	}
 }
 
+func TestRunCommandStdin(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("session failed: %v", err)
+	}
+	defer session.Close()
+
+	r, w := io.Pipe()
+	defer r.Close()
+	defer w.Close()
+	session.Stdin = r
+
+	err = session.Run("true")
+	if err != nil {
+		t.Fatalf("session failed: %v", err)
+	}
+}
+
+func TestRunCommandStdinError(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("session failed: %v", err)
+	}
+	defer session.Close()
+
+	r, w := io.Pipe()
+	defer r.Close()
+	session.Stdin = r
+	pipeErr := errors.New("closing write end of pipe")
+	w.CloseWithError(pipeErr)
+
+	err = session.Run("true")
+	if err != pipeErr {
+		t.Fatalf("expected %v, found %v", pipeErr, err)
+	}
+}
+
 func TestRunCommandFailed(t *testing.T) {
 	server := newServer(t)
 	defer server.Shutdown()
@@ -107,7 +156,7 @@
 		t.Fatalf("unable to acquire stdout pipe: %s", err)
 	}
 
-	err = session.Start("dd if=/dev/urandom bs=2048 count=1")
+	err = session.Start("dd if=/dev/urandom bs=2048 count=1024")
 	if err != nil {
 		t.Fatalf("unable to execute remote command: %s", err)
 	}
@@ -118,11 +167,53 @@
 		t.Fatalf("error reading from remote stdout: %s", err)
 	}
 
-	if n != 2048 {
+	if n != 2048*1024 {
 		t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n)
 	}
 }
 
+func TestKeyChange(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conf := clientConfig()
+	hostDB := hostKeyDB()
+	conf.HostKeyCallback = hostDB.Check
+	conf.RekeyThreshold = 1024
+	conn := server.Dial(conf)
+	defer conn.Close()
+
+	for i := 0; i < 4; i++ {
+		session, err := conn.NewSession()
+		if err != nil {
+			t.Fatalf("unable to create new session: %s", err)
+		}
+
+		stdout, err := session.StdoutPipe()
+		if err != nil {
+			t.Fatalf("unable to acquire stdout pipe: %s", err)
+		}
+
+		err = session.Start("dd if=/dev/urandom bs=1024 count=1")
+		if err != nil {
+			t.Fatalf("unable to execute remote command: %s", err)
+		}
+		buf := new(bytes.Buffer)
+		n, err := io.Copy(buf, stdout)
+		if err != nil {
+			t.Fatalf("error reading from remote stdout: %s", err)
+		}
+
+		want := int64(1024)
+		if n != want {
+			t.Fatalf("Expected %d bytes but read only %d from remote command", want, n)
+		}
+	}
+
+	if changes := hostDB.checkCount; changes < 4 {
+		t.Errorf("got %d key changes, want 4", changes)
+	}
+}
+
 func TestInvalidTerminalMode(t *testing.T) {
 	server := newServer(t)
 	defer server.Shutdown()
@@ -183,3 +274,44 @@
 		t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput)
 	}
 }
+
+func TestCiphers(t *testing.T) {
+	var config ssh.Config
+	config.SetDefaults()
+	cipherOrder := config.Ciphers
+
+	for _, ciph := range cipherOrder {
+		server := newServer(t)
+		defer server.Shutdown()
+		conf := clientConfig()
+		conf.Ciphers = []string{ciph}
+		// Don't fail if sshd doesnt have the cipher.
+		conf.Ciphers = append(conf.Ciphers, cipherOrder...)
+		conn, err := server.TryDial(conf)
+		if err == nil {
+			conn.Close()
+		} else {
+			t.Fatalf("failed for cipher %q", ciph)
+		}
+	}
+}
+
+func TestMACs(t *testing.T) {
+	var config ssh.Config
+	config.SetDefaults()
+	macOrder := config.MACs
+
+	for _, mac := range macOrder {
+		server := newServer(t)
+		defer server.Shutdown()
+		conf := clientConfig()
+		conf.MACs = []string{mac}
+		// Don't fail if sshd doesnt have the MAC.
+		conf.MACs = append(conf.MACs, macOrder...)
+		if conn, err := server.TryDial(conf); err == nil {
+			conn.Close()
+		} else {
+			t.Fatalf("failed for MAC %q", mac)
+		}
+	}
+}
diff --git a/ssh/test/tcpip_test.go b/ssh/test/tcpip_test.go
index ee06b60..a2eb935 100644
--- a/ssh/test/tcpip_test.go
+++ b/ssh/test/tcpip_test.go
@@ -9,39 +9,38 @@
 // direct-tcpip functional tests
 
 import (
+	"io"
 	"net"
-	"net/http"
 	"testing"
 )
 
-func TestTCPIPHTTP(t *testing.T) {
-	// google.com will generate at least one redirect, possibly three
-	// depending on your location.
-	doTest(t, "http://google.com")
-}
-
-func TestTCPIPHTTPS(t *testing.T) {
-	doTest(t, "https://encrypted.google.com/")
-}
-
-func doTest(t *testing.T, url string) {
+func TestDial(t *testing.T) {
 	server := newServer(t)
 	defer server.Shutdown()
-	conn := server.Dial(clientConfig())
-	defer conn.Close()
+	sshConn := server.Dial(clientConfig())
+	defer sshConn.Close()
 
-	tr := &http.Transport{
-		Dial: func(n, addr string) (net.Conn, error) {
-			return conn.Dial(n, addr)
-		},
-	}
-	client := &http.Client{
-		Transport: tr,
-	}
-	resp, err := client.Get(url)
+	l, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
-		t.Fatalf("unable to proxy: %s", err)
+		t.Fatalf("Listen: %v", err)
 	}
-	// got a body without error
-	t.Log(resp)
+	defer l.Close()
+
+	go func() {
+		for {
+			c, err := l.Accept()
+			if err != nil {
+				break
+			}
+
+			io.WriteString(c, c.RemoteAddr().String())
+			c.Close()
+		}
+	}()
+
+	conn, err := sshConn.Dial("tcp", l.Addr().String())
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+	defer conn.Close()
 }
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 86df3f4..f44c65d 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// +build darwin freebsd linux netbsd openbsd plan9
+// +build darwin freebsd linux netbsd openbsd
 
 package test
 
@@ -11,7 +11,6 @@
 import (
 	"bytes"
 	"fmt"
-	"io"
 	"io/ioutil"
 	"log"
 	"net"
@@ -23,13 +22,14 @@
 	"text/template"
 
 	"code.google.com/p/go.crypto/ssh"
+	"code.google.com/p/go.crypto/ssh/testdata"
 )
 
 const sshd_config = `
 Protocol 2
-HostKey {{.Dir}}/ssh_host_rsa_key
-HostKey {{.Dir}}/ssh_host_dsa_key
-HostKey {{.Dir}}/ssh_host_ecdsa_key
+HostKey {{.Dir}}/id_rsa
+HostKey {{.Dir}}/id_dsa
+HostKey {{.Dir}}/id_ecdsa
 Pidfile {{.Dir}}/sshd.pid
 #UsePrivilegeSeparation no
 KeyRegenerationInterval 3600
@@ -41,41 +41,14 @@
 StrictModes no
 RSAAuthentication yes
 PubkeyAuthentication yes
-AuthorizedKeysFile	{{.Dir}}/authorized_keys
+AuthorizedKeysFile	{{.Dir}}/id_user.pub
+TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
 IgnoreRhosts yes
 RhostsRSAAuthentication no
 HostbasedAuthentication no
 `
 
-var (
-	configTmpl   template.Template
-	privateKey   ssh.Signer
-	hostKeyRSA   ssh.Signer
-	hostKeyECDSA ssh.Signer
-	hostKeyDSA   ssh.Signer
-)
-
-func init() {
-	template.Must(configTmpl.Parse(sshd_config))
-
-	for n, k := range map[string]*ssh.Signer{
-		"ssh_host_ecdsa_key": &hostKeyECDSA,
-		"ssh_host_rsa_key":   &hostKeyRSA,
-		"ssh_host_dsa_key":   &hostKeyDSA,
-	} {
-		var err error
-		*k, err = ssh.ParsePrivateKey([]byte(keys[n]))
-		if err != nil {
-			panic(fmt.Sprintf("ParsePrivateKey(%q): %v", n, err))
-		}
-	}
-
-	var err error
-	privateKey, err = ssh.ParsePrivateKey([]byte(testClientPrivateKey))
-	if err != nil {
-		panic(fmt.Sprintf("ParsePrivateKey: %v", err))
-	}
-}
+var configTmpl = template.Must(template.New("").Parse(sshd_config))
 
 type server struct {
 	t          *testing.T
@@ -107,36 +80,44 @@
 type storedHostKey struct {
 	// keys map from an algorithm string to binary key data.
 	keys map[string][]byte
+
+	// checkCount counts the Check calls. Used for testing
+	// rekeying.
+	checkCount int
 }
 
 func (k *storedHostKey) Add(key ssh.PublicKey) {
 	if k.keys == nil {
 		k.keys = map[string][]byte{}
 	}
-	k.keys[key.PublicKeyAlgo()] = ssh.MarshalPublicKey(key)
+	k.keys[key.Type()] = key.Marshal()
 }
 
-func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
-	if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
+func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
+	k.checkCount++
+	algo := key.Type()
+
+	if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
 		return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
 	}
 	return nil
 }
 
-func clientConfig() *ssh.ClientConfig {
-	keyChecker := storedHostKey{}
-	keyChecker.Add(hostKeyECDSA.PublicKey())
-	keyChecker.Add(hostKeyRSA.PublicKey())
-	keyChecker.Add(hostKeyDSA.PublicKey())
+func hostKeyDB() *storedHostKey {
+	keyChecker := &storedHostKey{}
+	keyChecker.Add(testPublicKeys["ecdsa"])
+	keyChecker.Add(testPublicKeys["rsa"])
+	keyChecker.Add(testPublicKeys["dsa"])
+	return keyChecker
+}
 
-	kc := new(keychain)
-	kc.keys = append(kc.keys, privateKey)
+func clientConfig() *ssh.ClientConfig {
 	config := &ssh.ClientConfig{
 		User: username(),
-		Auth: []ssh.ClientAuth{
-			ssh.ClientAuthKeyring(kc),
+		Auth: []ssh.AuthMethod{
+			ssh.PublicKeys(testSigners["user"]),
 		},
-		HostKeyChecker: &keyChecker,
+		HostKeyCallback: hostKeyDB().Check,
 	}
 	return config
 }
@@ -171,7 +152,7 @@
 	return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
 }
 
-func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
+func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
 	sshd, err := exec.LookPath("sshd")
 	if err != nil {
 		s.t.Skipf("skipping test: %v", err)
@@ -197,10 +178,14 @@
 		s.t.Fatalf("s.cmd.Start: %v", err)
 	}
 	s.clientConn = c1
-	return ssh.Client(c1, config)
+	conn, chans, reqs, err := ssh.NewClientConn(c1, "", config)
+	if err != nil {
+		return nil, err
+	}
+	return ssh.NewClient(conn, chans, reqs), nil
 }
 
-func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
+func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
 	conn, err := s.TryDial(config)
 	if err != nil {
 		s.t.Fail()
@@ -226,6 +211,17 @@
 	s.cleanup()
 }
 
+func writeFile(path string, contents []byte) {
+	f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
+	if err != nil {
+		panic(err)
+	}
+	defer f.Close()
+	if _, err := f.Write(contents); err != nil {
+		panic(err)
+	}
+}
+
 // newServer returns a new mock ssh server.
 func newServer(t *testing.T) *server {
 	dir, err := ioutil.TempDir("", "sshtest")
@@ -244,15 +240,10 @@
 	}
 	f.Close()
 
-	for k, v := range keys {
-		f, err := os.OpenFile(filepath.Join(dir, k), os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if _, err := f.Write([]byte(v)); err != nil {
-			t.Fatal(err)
-		}
-		f.Close()
+	for k, v := range testdata.PEMBytes {
+		filename := "id_" + k
+		writeFile(filepath.Join(dir, filename), v)
+		writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
 	}
 
 	return &server{
@@ -265,32 +256,3 @@
 		},
 	}
 }
-
-// keychain implements the ClientKeyring interface.
-type keychain struct {
-	keys []ssh.Signer
-}
-
-func (k *keychain) Key(i int) (ssh.PublicKey, error) {
-	if i < 0 || i >= len(k.keys) {
-		return nil, nil
-	}
-	return k.keys[i].PublicKey(), nil
-}
-
-func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
-	return k.keys[i].Sign(rand, data)
-}
-
-func (k *keychain) loadPEM(file string) error {
-	buf, err := ioutil.ReadFile(file)
-	if err != nil {
-		return err
-	}
-	key, err := ssh.ParsePrivateKey(buf)
-	if err != nil {
-		return err
-	}
-	k.keys = append(k.keys, key)
-	return nil
-}
diff --git a/ssh/test/testdata_test.go b/ssh/test/testdata_test.go
new file mode 100644
index 0000000..7f50fbe
--- /dev/null
+++ b/ssh/test/testdata_test.go
@@ -0,0 +1,64 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
+// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
+// instances.
+
+package test
+
+import (
+	"crypto/rand"
+	"fmt"
+
+	"code.google.com/p/go.crypto/ssh"
+	"code.google.com/p/go.crypto/ssh/testdata"
+)
+
+var (
+	testPrivateKeys map[string]interface{}
+	testSigners     map[string]ssh.Signer
+	testPublicKeys  map[string]ssh.PublicKey
+)
+
+func init() {
+	var err error
+
+	n := len(testdata.PEMBytes)
+	testPrivateKeys = make(map[string]interface{}, n)
+	testSigners = make(map[string]ssh.Signer, n)
+	testPublicKeys = make(map[string]ssh.PublicKey, n)
+	for t, k := range testdata.PEMBytes {
+		testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
+		if err != nil {
+			panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
+		}
+		testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
+		if err != nil {
+			panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
+		}
+		testPublicKeys[t] = testSigners[t].PublicKey()
+	}
+
+	// Create a cert and sign it for use in tests.
+	testCert := &ssh.Certificate{
+		Nonce:           []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
+		ValidAfter:      0,                              // unix epoch
+		ValidBefore:     ssh.CertTimeInfinity,           // The end of currently representable time.
+		Reserved:        []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		Key:             testPublicKeys["ecdsa"],
+		SignatureKey:    testPublicKeys["rsa"],
+		Permissions: ssh.Permissions{
+			CriticalOptions: map[string]string{},
+			Extensions:      map[string]string{},
+		},
+	}
+	testCert.SignCert(rand.Reader, testSigners["rsa"])
+	testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
+	testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
+	if err != nil {
+		panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
+	}
+}
diff --git a/ssh/testdata/doc.go b/ssh/testdata/doc.go
new file mode 100644
index 0000000..4302486
--- /dev/null
+++ b/ssh/testdata/doc.go
@@ -0,0 +1,8 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This package contains test data shared between the various subpackages of
+// the code.google.com/p/go.crypto/ssh package. Under no circumstance should
+// this data be used for production code.
+package testdata
diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go
new file mode 100644
index 0000000..5ff1c0e
--- /dev/null
+++ b/ssh/testdata/keys.go
@@ -0,0 +1,43 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testdata
+
+var PEMBytes = map[string][]byte{
+	"dsa": []byte(`-----BEGIN DSA PRIVATE KEY-----
+MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB
+lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
+EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD
+nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV
+2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r
+juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr
+FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz
+DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj
+nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY
+Fmsr0W6fHB9nhS4/UXM8
+-----END DSA PRIVATE KEY-----
+`),
+	"ecdsa": []byte(`-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
+AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
+6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
+-----END EC PRIVATE KEY-----
+`),
+	"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
+r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
+tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
+nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
+2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
+y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
+rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
+-----END RSA PRIVATE KEY-----
+`),
+	"user": []byte(`-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
+AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD
+PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w==
+-----END EC PRIVATE KEY-----
+`),
+}
diff --git a/ssh/testdata_test.go b/ssh/testdata_test.go
new file mode 100644
index 0000000..302fdc8
--- /dev/null
+++ b/ssh/testdata_test.go
@@ -0,0 +1,63 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
+// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
+// instances.
+
+package ssh
+
+import (
+	"crypto/rand"
+	"fmt"
+
+	"code.google.com/p/go.crypto/ssh/testdata"
+)
+
+var (
+	testPrivateKeys map[string]interface{}
+	testSigners     map[string]Signer
+	testPublicKeys  map[string]PublicKey
+)
+
+func init() {
+	var err error
+
+	n := len(testdata.PEMBytes)
+	testPrivateKeys = make(map[string]interface{}, n)
+	testSigners = make(map[string]Signer, n)
+	testPublicKeys = make(map[string]PublicKey, n)
+	for t, k := range testdata.PEMBytes {
+		testPrivateKeys[t], err = ParseRawPrivateKey(k)
+		if err != nil {
+			panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
+		}
+		testSigners[t], err = NewSignerFromKey(testPrivateKeys[t])
+		if err != nil {
+			panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
+		}
+		testPublicKeys[t] = testSigners[t].PublicKey()
+	}
+
+	// Create a cert and sign it for use in tests.
+	testCert := &Certificate{
+		Nonce:           []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
+		ValidAfter:      0,                              // unix epoch
+		ValidBefore:     CertTimeInfinity,               // The end of currently representable time.
+		Reserved:        []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		Key:             testPublicKeys["ecdsa"],
+		SignatureKey:    testPublicKeys["rsa"],
+		Permissions: Permissions{
+			CriticalOptions: map[string]string{},
+			Extensions:      map[string]string{},
+		},
+	}
+	testCert.SignCert(rand.Reader, testSigners["rsa"])
+	testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
+	testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"])
+	if err != nil {
+		panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
+	}
+}
diff --git a/ssh/transport.go b/ssh/transport.go
index 46fa262..4f68b04 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -6,26 +6,12 @@
 
 import (
 	"bufio"
-	"crypto/cipher"
-	"crypto/subtle"
-	"encoding/binary"
 	"errors"
-	"hash"
 	"io"
-	"net"
-	"sync"
 )
 
 const (
-	packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
-
-	// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
-	// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
-	// indicates implementations SHOULD be able to handle larger packet sizes, but then
-	// waffles on about reasonable limits.
-	//
-	// OpenSSH caps their maxPacket at 256kb so we choose to do the same.
-	maxPacket = 256 * 1024
+	gcmCipherID = "aes128-gcm@openssh.com"
 )
 
 // packetConn represents a transport that implements packet based
@@ -41,225 +27,128 @@
 	Close() error
 }
 
-// transport represents the SSH connection to the remote peer.
+// transport is the keyingTransport that implements the SSH packet
+// protocol.
 type transport struct {
-	reader
-	writer
+	reader connectionState
+	writer connectionState
 
-	net.Conn
+	bufReader *bufio.Reader
+	bufWriter *bufio.Writer
+	rand      io.Reader
+
+	io.Closer
 
 	// Initial H used for the session ID. Once assigned this does
 	// not change, even during subsequent key exchanges.
 	sessionID []byte
 }
 
-// reader represents the incoming connection state.
-type reader struct {
-	io.Reader
-	common
+func (t *transport) getSessionID() []byte {
+	if t.sessionID == nil {
+		panic("session ID not set yet")
+	}
+	s := make([]byte, len(t.sessionID))
+	copy(s, t.sessionID)
+	return s
 }
 
-// writer represents the outgoing connection state.
-type writer struct {
-	sync.Mutex // protects writer.Writer from concurrent writes
-	*bufio.Writer
-	rand io.Reader
-	common
+// packetCipher represents a combination of SSH encryption/MAC
+// protocol.  A single instance should be used for one direction only.
+type packetCipher interface {
+	// writePacket encrypts the packet and writes it to w. The
+	// contents of the packet are generally scrambled.
+	writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
+
+	// readPacket reads and decrypts a packet of data. The
+	// returned packet may be overwritten by future calls of
+	// readPacket.
+	readPacket(seqnum uint32, r io.Reader) ([]byte, error)
+}
+
+// connectionState represents one side (read or write) of the
+// connection. This is necessary because each direction has its own
+// keys, and can even have its own algorithms
+type connectionState struct {
+	packetCipher
+	seqNum           uint32
+	dir              direction
+	pendingKeyChange chan packetCipher
 }
 
 // prepareKeyChange sets up key material for a keychange. The key changes in
 // both directions are triggered by reading and writing a msgNewKey packet
 // respectively.
 func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
-	t.writer.cipherAlgo = algs.wCipher
-	t.writer.macAlgo = algs.wMAC
-	t.writer.compressionAlgo = algs.wCompression
-
-	t.reader.cipherAlgo = algs.rCipher
-	t.reader.macAlgo = algs.rMAC
-	t.reader.compressionAlgo = algs.rCompression
-
 	if t.sessionID == nil {
 		t.sessionID = kexResult.H
 	}
 
 	kexResult.SessionID = t.sessionID
-	t.reader.pendingKeyChange <- kexResult
-	t.writer.pendingKeyChange <- kexResult
+
+	if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
+		return err
+	} else {
+		t.reader.pendingKeyChange <- ciph
+	}
+
+	if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil {
+		return err
+	} else {
+		t.writer.pendingKeyChange <- ciph
+	}
+
 	return nil
 }
 
-// common represents the cipher state needed to process messages in a single
-// direction.
-type common struct {
-	seqNum uint32
-	mac    hash.Hash
-	cipher cipher.Stream
-
-	cipherAlgo      string
-	macAlgo         string
-	compressionAlgo string
-
-	dir              direction
-	pendingKeyChange chan *kexResult
+// Read and decrypt next packet.
+func (t *transport) readPacket() ([]byte, error) {
+	return t.reader.readPacket(t.bufReader)
 }
 
-// Read and decrypt a single packet from the remote peer.
-func (r *reader) readPacket() ([]byte, error) {
-	var lengthBytes = make([]byte, 5)
-	var macSize uint32
-	if _, err := io.ReadFull(r, lengthBytes); err != nil {
-		return nil, err
+func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
+	packet, err := s.packetCipher.readPacket(s.seqNum, r)
+	s.seqNum++
+	if err == nil && len(packet) == 0 {
+		err = errors.New("ssh: zero length packet")
 	}
 
-	r.cipher.XORKeyStream(lengthBytes, lengthBytes)
-
-	if r.mac != nil {
-		r.mac.Reset()
-		seqNumBytes := []byte{
-			byte(r.seqNum >> 24),
-			byte(r.seqNum >> 16),
-			byte(r.seqNum >> 8),
-			byte(r.seqNum),
-		}
-		r.mac.Write(seqNumBytes)
-		r.mac.Write(lengthBytes)
-		macSize = uint32(r.mac.Size())
-	}
-
-	length := binary.BigEndian.Uint32(lengthBytes[0:4])
-	paddingLength := uint32(lengthBytes[4])
-
-	if length <= paddingLength+1 {
-		return nil, errors.New("ssh: invalid packet length, packet too small")
-	}
-
-	if length > maxPacket {
-		return nil, errors.New("ssh: invalid packet length, packet too large")
-	}
-
-	packet := make([]byte, length-1+macSize)
-	if _, err := io.ReadFull(r, packet); err != nil {
-		return nil, err
-	}
-	mac := packet[length-1:]
-	r.cipher.XORKeyStream(packet, packet[:length-1])
-
-	if r.mac != nil {
-		r.mac.Write(packet[:length-1])
-		if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 {
-			return nil, errors.New("ssh: MAC failure")
-		}
-	}
-
-	r.seqNum++
-	packet = packet[:length-paddingLength-1]
-
 	if len(packet) > 0 && packet[0] == msgNewKeys {
 		select {
-		case k := <-r.pendingKeyChange:
-			if err := r.setupKeys(r.dir, k); err != nil {
-				return nil, err
-			}
+		case cipher := <-s.pendingKeyChange:
+			s.packetCipher = cipher
 		default:
 			return nil, errors.New("ssh: got bogus newkeys message.")
 		}
 	}
-	return packet, nil
+
+	// The packet may point to an internal buffer, so copy the
+	// packet out here.
+	fresh := make([]byte, len(packet))
+	copy(fresh, packet)
+
+	return fresh, err
 }
 
-// Read and decrypt next packet discarding debug and noop messages.
-func (t *transport) readPacket() ([]byte, error) {
-	for {
-		packet, err := t.reader.readPacket()
-		if err != nil {
-			return nil, err
-		}
-		if len(packet) == 0 {
-			return nil, errors.New("ssh: zero length packet")
-		}
-
-		if packet[0] != msgIgnore && packet[0] != msgDebug {
-			return packet, nil
-		}
-	}
-	panic("unreachable")
+func (t *transport) writePacket(packet []byte) error {
+	return t.writer.writePacket(t.bufWriter, t.rand, packet)
 }
 
-// Encrypt and send a packet of data to the remote peer.
-func (w *writer) writePacket(packet []byte) error {
+func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
 	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
 
-	if len(packet) > maxPacket {
-		return errors.New("ssh: packet too large")
-	}
-	w.Mutex.Lock()
-	defer w.Mutex.Unlock()
-
-	paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
-	if paddingLength < 4 {
-		paddingLength += packetSizeMultiple
-	}
-
-	length := len(packet) + 1 + paddingLength
-	lengthBytes := []byte{
-		byte(length >> 24),
-		byte(length >> 16),
-		byte(length >> 8),
-		byte(length),
-		byte(paddingLength),
-	}
-	padding := make([]byte, paddingLength)
-	_, err := io.ReadFull(w.rand, padding)
+	err := s.packetCipher.writePacket(s.seqNum, w, rand, packet)
 	if err != nil {
 		return err
 	}
-
-	if w.mac != nil {
-		w.mac.Reset()
-		seqNumBytes := []byte{
-			byte(w.seqNum >> 24),
-			byte(w.seqNum >> 16),
-			byte(w.seqNum >> 8),
-			byte(w.seqNum),
-		}
-		w.mac.Write(seqNumBytes)
-		w.mac.Write(lengthBytes)
-		w.mac.Write(packet)
-		w.mac.Write(padding)
-	}
-
-	// TODO(dfc) lengthBytes, packet and padding should be
-	// subslices of a single buffer
-	w.cipher.XORKeyStream(lengthBytes, lengthBytes)
-	w.cipher.XORKeyStream(packet, packet)
-	w.cipher.XORKeyStream(padding, padding)
-
-	if _, err := w.Write(lengthBytes); err != nil {
-		return err
-	}
-	if _, err := w.Write(packet); err != nil {
-		return err
-	}
-	if _, err := w.Write(padding); err != nil {
-		return err
-	}
-
-	if w.mac != nil {
-		if _, err := w.Write(w.mac.Sum(nil)); err != nil {
-			return err
-		}
-	}
-
-	w.seqNum++
 	if err = w.Flush(); err != nil {
 		return err
 	}
-
+	s.seqNum++
 	if changeKeys {
 		select {
-		case k := <-w.pendingKeyChange:
-			err = w.setupKeys(w.dir, k)
+		case cipher := <-s.pendingKeyChange:
+			s.packetCipher = cipher
 		default:
 			panic("ssh: no key material for msgNewKeys")
 		}
@@ -267,24 +156,20 @@
 	return err
 }
 
-func newTransport(conn net.Conn, rand io.Reader, isClient bool) *transport {
+func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
 	t := &transport{
-		reader: reader{
-			Reader: bufio.NewReader(conn),
-			common: common{
-				cipher:           noneCipher{},
-				pendingKeyChange: make(chan *kexResult, 1),
-			},
+		bufReader: bufio.NewReader(rwc),
+		bufWriter: bufio.NewWriter(rwc),
+		rand:      rand,
+		reader: connectionState{
+			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
+			pendingKeyChange: make(chan packetCipher, 1),
 		},
-		writer: writer{
-			Writer: bufio.NewWriter(conn),
-			rand:   rand,
-			common: common{
-				cipher:           noneCipher{},
-				pendingKeyChange: make(chan *kexResult, 1),
-			},
+		writer: connectionState{
+			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
+			pendingKeyChange: make(chan packetCipher, 1),
 		},
-		Conn: conn,
+		Closer: rwc,
 	}
 	if isClient {
 		t.reader.dir = serverKeys
@@ -303,48 +188,64 @@
 	macKeyTag []byte
 }
 
-// TODO(dfc) can this be made a constant ?
 var (
 	serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
 	clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
 )
 
+// generateKeys generates key material for IV, MAC and encryption.
+func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
+	cipherMode := cipherModes[algs.Cipher]
+	macMode := macModes[algs.MAC]
+
+	iv = make([]byte, cipherMode.ivSize)
+	key = make([]byte, cipherMode.keySize)
+	macKey = make([]byte, macMode.keySize)
+
+	generateKeyMaterial(iv, d.ivTag, kex)
+	generateKeyMaterial(key, d.keyTag, kex)
+	generateKeyMaterial(macKey, d.macKeyTag, kex)
+	return
+}
+
 // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
 // described in RFC 4253, section 6.4. direction should either be serverKeys
 // (to setup server->client keys) or clientKeys (for client->server keys).
-func (c *common) setupKeys(d direction, r *kexResult) error {
-	cipherMode := cipherModes[c.cipherAlgo]
-	macMode := macModes[c.macAlgo]
+func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
+	iv, key, macKey := generateKeys(d, algs, kex)
 
-	iv := make([]byte, cipherMode.ivSize)
-	key := make([]byte, cipherMode.keySize)
-	macKey := make([]byte, macMode.keySize)
+	if algs.Cipher == gcmCipherID {
+		return newGCMCipher(iv, key, macKey)
+	}
 
-	h := r.Hash.New()
-	generateKeyMaterial(iv, d.ivTag, r.K, r.H, r.SessionID, h)
-	generateKeyMaterial(key, d.keyTag, r.K, r.H, r.SessionID, h)
-	generateKeyMaterial(macKey, d.macKeyTag, r.K, r.H, r.SessionID, h)
-
-	c.mac = macMode.new(macKey)
+	c := &streamPacketCipher{
+		mac: macModes[algs.MAC].new(macKey),
+	}
+	c.macResult = make([]byte, c.mac.Size())
 
 	var err error
-	c.cipher, err = cipherMode.createCipher(key, iv)
-	return err
+	c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
+	if err != nil {
+		return nil, err
+	}
+
+	return c, nil
 }
 
 // generateKeyMaterial fills out with key material generated from tag, K, H
 // and sessionId, as specified in RFC 4253, section 7.2.
-func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
+func generateKeyMaterial(out, tag []byte, r *kexResult) {
 	var digestsSoFar []byte
 
+	h := r.Hash.New()
 	for len(out) > 0 {
 		h.Reset()
-		h.Write(K)
-		h.Write(H)
+		h.Write(r.K)
+		h.Write(r.H)
 
 		if len(digestsSoFar) == 0 {
 			h.Write(tag)
-			h.Write(sessionId)
+			h.Write(r.SessionID)
 		} else {
 			h.Write(digestsSoFar)
 		}
diff --git a/ssh/transport_test.go b/ssh/transport_test.go
index 3320114..92d83ab 100644
--- a/ssh/transport_test.go
+++ b/ssh/transport_test.go
@@ -6,6 +6,8 @@
 
 import (
 	"bytes"
+	"crypto/rand"
+	"encoding/binary"
 	"strings"
 	"testing"
 )
@@ -67,3 +69,41 @@
 		}
 	}
 }
+
+type closerBuffer struct {
+	bytes.Buffer
+}
+
+func (b *closerBuffer) Close() error {
+	return nil
+}
+
+func TestTransportMaxPacketWrite(t *testing.T) {
+	buf := &closerBuffer{}
+	tr := newTransport(buf, rand.Reader, true)
+	huge := make([]byte, maxPacket+1)
+	err := tr.writePacket(huge)
+	if err == nil {
+		t.Errorf("transport accepted write for a huge packet.")
+	}
+}
+
+func TestTransportMaxPacketReader(t *testing.T) {
+	var header [5]byte
+	huge := make([]byte, maxPacket+128)
+	binary.BigEndian.PutUint32(header[0:], uint32(len(huge)))
+	// padding.
+	header[4] = 0
+
+	buf := &closerBuffer{}
+	buf.Write(header[:])
+	buf.Write(huge)
+
+	tr := newTransport(buf, rand.Reader, true)
+	_, err := tr.readPacket()
+	if err == nil {
+		t.Errorf("transport succeeded reading huge packet.")
+	} else if !strings.Contains(err.Error(), "large") {
+		t.Errorf("got %q, should mention %q", err.Error(), "large")
+	}
+}