go.crypto/ssh: import gosshnew.
See https://groups.google.com/d/msg/Golang-nuts/AoVxQ4bB5XQ/i8kpMxdbVlEJ
R=hanwen
CC=golang-codereviews
https://golang.org/cl/86190043
diff --git a/ssh/agent/client.go b/ssh/agent/client.go
new file mode 100644
index 0000000..9c11d32
--- /dev/null
+++ b/ssh/agent/client.go
@@ -0,0 +1,563 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
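+ Package agent implements the ssh-agent protocol: a client for an ssh-agent daemon, an in-memory keyring that implements the same Agent interface, a server for the agent wire protocol, and helpers for agent forwarding.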
+
+References:
+ [PROTOCOL.agent]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent
+*/
+package agent
+
+import (
+ "bytes"
+ "crypto/dsa"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "sync"
+
+ "code.google.com/p/go.crypto/ssh"
+)
+
+// Agent represents the capabilities of an ssh-agent.
+type Agent interface {
+	// List returns the identities known to the agent.
+	List() ([]*Key, error)
+
+	// Sign has the agent sign the data using a protocol 2 key as defined
+	// in [PROTOCOL.agent] section 2.6.2.
+	Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
+
+	// Add adds a private key to the agent. If a certificate is given,
+	// that certificate is added as the public key. s is a private key
+	// such as *rsa.PrivateKey, *dsa.PrivateKey or *ecdsa.PrivateKey.
+	Add(s interface{}, cert *ssh.Certificate, comment string) error
+
+	// Remove removes all identities with the given public key.
+	Remove(key ssh.PublicKey) error
+
+	// RemoveAll removes all identities.
+	RemoveAll() error
+
+	// Lock locks the agent. Sign, Remove and RemoveAll will fail, and List
+	// will return an empty list, until Unlock is called with the same
+	// passphrase.
+	Lock(passphrase []byte) error
+
+	// Unlock undoes the effect of Lock.
+	Unlock(passphrase []byte) error
+
+	// Signers returns signers for all the known keys.
+	Signers() ([]ssh.Signer, error)
+}
+
+// See [PROTOCOL.agent], section 3.
+// The message numbers below are the first byte of each agent request; their
+// values are fixed by the protocol and must not be changed.
+const (
+ // Protocol 1 identity request. The server in this package answers it
+ // with an empty list.
+ agentRequestV1Identities = 1
+
+ // 3.2 Requests from client to agent for protocol 2 key operations
+ agentAddIdentity = 17
+ agentRemoveIdentity = 18
+ agentRemoveAllIdentities = 19
+ agentAddIdConstrained = 25
+
+ // 3.3 Key-type independent requests from client to agent
+ agentAddSmartcardKey = 20
+ agentRemoveSmartcardKey = 21
+ agentLock = 22
+ agentUnlock = 23
+ agentAddSmartcardKeyConstrained = 26
+
+ // 3.7 Key constraint identifiers
+ // These are not message numbers. They identify constraints inside the
+ // *Constrained add requests, so they may share values with message
+ // numbers above. Nothing in the code visible here uses them yet.
+ agentConstrainLifetime = 1
+ agentConstrainConfirm = 2
+)
+
+// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
+// is a sanity check, not a limit in the spec.
+// It is 16 MiB. ServeAgent also applies it to incoming requests, and List
+// uses it to bound the number of keys it will allocate for.
+const maxAgentResponseBytes = 16 << 20
+
+// Agent messages:
+// These structures mirror the wire format of the corresponding ssh agent
+// messages found in [PROTOCOL.agent]. Fields are encoded in declaration
+// order, and the sshtype tag on the first field is the message number, so
+// field order and tags must not be changed.
+
+// 3.4 Generic replies from agent to client
+// A failure reply is the single byte agentFailure with no payload.
+const agentFailure = 5
+
+type failureAgentMsg struct{}
+
+// A success reply is the single byte agentSuccess with no payload.
+const agentSuccess = 6
+
+type successAgentMsg struct{}
+
+// See [PROTOCOL.agent], section 2.5.2.
+// The request consists of the type byte alone.
+const agentRequestIdentities = 11
+
+type requestIdentitiesAgentMsg struct{}
+
+// See [PROTOCOL.agent], section 2.5.2.
+const agentIdentitiesAnswer = 12
+
+// identitiesAnswerAgentMsg lists the agent's keys. Keys holds NumKeys
+// concatenated records, each a (key blob, comment) pair as decoded by
+// parseKey and encoded by marshalKey.
+type identitiesAnswerAgentMsg struct {
+ NumKeys uint32 `sshtype:"12"`
+ Keys []byte `ssh:"rest"`
+}
+
+// See [PROTOCOL.agent], section 2.6.2.
+const agentSignRequest = 13
+
+// signRequestAgentMsg asks the agent to sign Data with the key whose
+// serialized public key is KeyBlob. The client in this package always sends
+// Flags as zero, and the server does not yet interpret it.
+type signRequestAgentMsg struct {
+ KeyBlob []byte `sshtype:"13"`
+ Data []byte
+ Flags uint32
+}
+
+// 3.6 Replies from agent to client for protocol 2 key operations
+// (see also [PROTOCOL.agent], section 2.6.2).
+const agentSignResponse = 14
+
+// signResponseAgentMsg carries a serialized ssh.Signature in SigBlob.
+type signResponseAgentMsg struct {
+ SigBlob []byte `sshtype:"14"`
+}
+
+// publicKey splits a serialized public key into its format name and the
+// remaining format-specific bytes.
+// NOTE(review): this has the same layout as wireKey and is not used by the
+// code visible here; confirm no other file uses it before removing.
+type publicKey struct {
+ Format string
+ Rest []byte `ssh:"rest"`
+}
+
+// Key represents a protocol 2 public key as defined in
+// [PROTOCOL.agent], section 2.5.2. Format is the key type name (also
+// returned by Type), Blob is the serialized public key (also returned by
+// Marshal), and Comment is the comment stored with the key.
+type Key struct {
+ Format string
+ Blob []byte
+ Comment string
+}
+
+// clientErr prefixes an error from the agent connection or from decoding a
+// reply so that callers can tell it came from the agent client.
+func clientErr(cause error) error {
+	msg := fmt.Sprintf("agent: client error: %v", cause)
+	return errors.New(msg)
+}
+
+// String renders the key in its storage form: the format name, a space,
+// the base64 encoding of the key blob and, if the comment is non-empty, a
+// space followed by the comment.
+func (k *Key) String() string {
+	encoded := k.Format + " " + base64.StdEncoding.EncodeToString(k.Blob)
+	if k.Comment == "" {
+		return encoded
+	}
+	return encoded + " " + k.Comment
+}
+
+// Type reports the key's format name, e.g. ssh.KeyAlgoRSA.
+func (k *Key) Type() string { return k.Format }
+
+// Marshal returns the serialized public key, satisfying ssh.PublicKey.
+func (k *Key) Marshal() []byte { return k.Blob }
+
+// Verify exists only to satisfy ssh.PublicKey. An agent key carries no
+// parsed key material, so it always fails.
+func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
+	return errors.New("agent: agent key does not know how to verify")
+}
+
+// wireKey decodes only the leading format string of a serialized public
+// key. Rest receives the remaining format-specific bytes unparsed.
+type wireKey struct {
+ Format string
+ Rest []byte `ssh:"rest"`
+}
+
+// parseKey decodes one (key blob, comment) record from the front of in, as
+// found in an identities answer. It returns the decoded Key and the
+// unconsumed remainder of in.
+func parseKey(in []byte) (out *Key, rest []byte, err error) {
+	var entry struct {
+		Blob    []byte
+		Comment string
+		Rest    []byte `ssh:"rest"`
+	}
+	if err = ssh.Unmarshal(in, &entry); err != nil {
+		return nil, nil, err
+	}
+
+	// Only the format name is extracted from the blob; the blob itself is
+	// kept verbatim in Key.Blob.
+	var header wireKey
+	if err = ssh.Unmarshal(entry.Blob, &header); err != nil {
+		return nil, nil, err
+	}
+
+	out = &Key{
+		Format:  header.Format,
+		Blob:    entry.Blob,
+		Comment: entry.Comment,
+	}
+	return out, entry.Rest, nil
+}
+
+// client speaks the ssh-agent wire protocol over a single connection.
+type client struct {
+	// conn carries the length-prefixed requests and replies; typically it
+	// is a *net.UnixConn.
+	conn io.ReadWriter
+
+	// mu serializes calls so that each request is paired with its reply.
+	mu sync.Mutex
+}
+
+// NewClient returns an Agent that talks to an ssh-agent process over
+// the given connection.
+func NewClient(rw io.ReadWriter) Agent {
+	c := new(client)
+	c.conn = rw
+	return c
+}
+
+// call sends req to the agent as one length-prefixed message and reads one
+// length-prefixed reply. The reply is decoded by unmarshal into one of the
+// *AgentMsg types. c.mu is held for the whole exchange so that concurrent
+// callers cannot interleave requests and replies. Every error is wrapped
+// with clientErr.
+func (c *client) call(req []byte) (reply interface{}, err error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	msg := make([]byte, 4+len(req))
+	binary.BigEndian.PutUint32(msg, uint32(len(req)))
+	copy(msg[4:], req)
+	if _, err = c.conn.Write(msg); err != nil {
+		return nil, clientErr(err)
+	}
+
+	var respSizeBuf [4]byte
+	if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil {
+		return nil, clientErr(err)
+	}
+	respSize := binary.BigEndian.Uint32(respSizeBuf[:])
+	if respSize > maxAgentResponseBytes {
+		// err is nil here, so describe the problem explicitly. The oversized
+		// body is left unread, so the connection is out of sync afterwards.
+		return nil, clientErr(fmt.Errorf("response too large: %d bytes", respSize))
+	}
+
+	buf := make([]byte, respSize)
+	if _, err = io.ReadFull(c.conn, buf); err != nil {
+		return nil, clientErr(err)
+	}
+	reply, err = unmarshal(buf)
+	if err != nil {
+		return nil, clientErr(err)
+	}
+	return reply, nil
+}
+
+// simpleCall sends req and succeeds only if the agent answers with a bare
+// success message. Any other reply is reported as a generic failure.
+func (c *client) simpleCall(req []byte) error {
+	resp, err := c.call(req)
+	if err != nil {
+		return err
+	}
+	if _, isSuccess := resp.(*successAgentMsg); !isSuccess {
+		return errors.New("agent: failure")
+	}
+	return nil
+}
+
+// RemoveAll asks the agent to forget every identity it holds.
+func (c *client) RemoveAll() error {
+	req := []byte{agentRemoveAllIdentities}
+	return c.simpleCall(req)
+}
+
+// Remove asks the agent to forget the identity whose public key is key.
+func (c *client) Remove(key ssh.PublicKey) error {
+	msg := agentRemoveIdentityMsg{KeyBlob: key.Marshal()}
+	return c.simpleCall(ssh.Marshal(&msg))
+}
+
+// Lock asks the agent to lock itself with passphrase.
+func (c *client) Lock(passphrase []byte) error {
+	msg := agentLockMsg{Passphrase: passphrase}
+	return c.simpleCall(ssh.Marshal(&msg))
+}
+
+// Unlock asks the agent to unlock itself using passphrase.
+func (c *client) Unlock(passphrase []byte) error {
+	msg := agentUnlockMsg{Passphrase: passphrase}
+	return c.simpleCall(ssh.Marshal(&msg))
+}
+
+// List returns the identities known to the agent, as described in
+// [PROTOCOL.agent] section 2.5.2.
+func (c *client) List() ([]*Key, error) {
+	req := []byte{agentRequestIdentities}
+
+	msg, err := c.call(req)
+	if err != nil {
+		return nil, err
+	}
+
+	switch msg := msg.(type) {
+	case *identitiesAnswerAgentMsg:
+		// NumKeys comes from the peer; bound the allocation below by what
+		// could plausibly fit in a maximum-size reply.
+		if msg.NumKeys > maxAgentResponseBytes/8 {
+			return nil, errors.New("ssh: too many keys in agent reply")
+		}
+		keys := make([]*Key, msg.NumKeys)
+		data := msg.Keys
+		for i := range keys {
+			key, rest, err := parseKey(data)
+			if err != nil {
+				return nil, err
+			}
+			keys[i], data = key, rest
+		}
+		return keys, nil
+	case *failureAgentMsg:
+		return nil, errors.New("ssh: failed to list keys")
+	}
+	// A well-formed reply of another type is a protocol violation by the
+	// agent, not a programming error, so report it instead of panicking.
+	return nil, fmt.Errorf("agent: unexpected reply %T to list request", msg)
+}
+
+// Sign has the agent sign the data using a protocol 2 key as defined
+// in [PROTOCOL.agent] section 2.6.2. No signature flags are sent.
+func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+	req := ssh.Marshal(signRequestAgentMsg{
+		KeyBlob: key.Marshal(),
+		Data:    data,
+	})
+
+	msg, err := c.call(req)
+	if err != nil {
+		return nil, err
+	}
+
+	switch msg := msg.(type) {
+	case *signResponseAgentMsg:
+		var sig ssh.Signature
+		if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil {
+			return nil, err
+		}
+		return &sig, nil
+	case *failureAgentMsg:
+		return nil, errors.New("ssh: failed to sign challenge")
+	}
+	// A well-formed reply of another type is a protocol violation by the
+	// agent; report it rather than panicking in library code.
+	return nil, fmt.Errorf("agent: unexpected reply %T to sign request", msg)
+}
+
+// unmarshal decodes an agent reply packet into one of the *AgentMsg types,
+// chosen by the packet's leading type byte. Success and failure replies
+// have no payload and are returned without further parsing. Unknown type
+// bytes are an error.
+func unmarshal(packet []byte) (interface{}, error) {
+	if len(packet) == 0 {
+		return nil, errors.New("agent: empty packet")
+	}
+
+	tag := packet[0]
+	if tag == agentFailure {
+		return new(failureAgentMsg), nil
+	}
+	if tag == agentSuccess {
+		return new(successAgentMsg), nil
+	}
+
+	var target interface{}
+	switch tag {
+	case agentIdentitiesAnswer:
+		target = new(identitiesAnswerAgentMsg)
+	case agentSignResponse:
+		target = new(signResponseAgentMsg)
+	default:
+		return nil, fmt.Errorf("agent: unknown type tag %d", tag)
+	}
+
+	if err := ssh.Unmarshal(packet, target); err != nil {
+		return nil, err
+	}
+	return target, nil
+}
+
+// rsaKeyMsg, dsaKeyMsg and ecdsaKeyMsg are the key-type specific layouts of
+// the add-identity request (message 17, agentAddIdentity). Type is the key
+// algorithm name and determines which layout follows. Field order is the
+// wire order; do not reorder.
+type rsaKeyMsg struct {
+ Type string `sshtype:"17"`
+ N *big.Int
+ E *big.Int
+ D *big.Int
+ Iqmp *big.Int // IQMP = Inverse Q Mod P
+ P *big.Int
+ Q *big.Int
+ Comments string
+}
+
+type dsaKeyMsg struct {
+ Type string `sshtype:"17"`
+ P *big.Int
+ Q *big.Int
+ G *big.Int
+ Y *big.Int
+ X *big.Int
+ Comments string
+}
+
+// For ECDSA, Curve is the curve identifier (e.g. "nistp256"), KeyBytes is
+// the public point as produced by elliptic.Marshal, and D is the private
+// scalar.
+type ecdsaKeyMsg struct {
+ Type string `sshtype:"17"`
+ Curve string
+ KeyBytes []byte
+ D *big.Int
+ Comments string
+}
+
+// insertKey adds the private key s to the agent with the given comment,
+// using the add-identity message (type 17). Supported key types are
+// *rsa.PrivateKey with exactly two primes, *dsa.PrivateKey and
+// *ecdsa.PrivateKey.
+func (c *client) insertKey(s interface{}, comment string) error {
+	var req []byte
+	switch k := s.(type) {
+	case *rsa.PrivateKey:
+		if len(k.Primes) != 2 {
+			return fmt.Errorf("ssh: unsupported RSA key with %d primes", len(k.Primes))
+		}
+		// Precompute fills in Precomputed.Qinv, which the message carries.
+		// NOTE(review): this mutates the caller's key.
+		k.Precompute()
+		req = ssh.Marshal(rsaKeyMsg{
+			Type:     ssh.KeyAlgoRSA,
+			N:        k.N,
+			E:        big.NewInt(int64(k.E)),
+			D:        k.D,
+			Iqmp:     k.Precomputed.Qinv,
+			P:        k.Primes[0],
+			Q:        k.Primes[1],
+			Comments: comment,
+		})
+	case *dsa.PrivateKey:
+		req = ssh.Marshal(dsaKeyMsg{
+			Type:     ssh.KeyAlgoDSA,
+			P:        k.P,
+			Q:        k.Q,
+			G:        k.G,
+			Y:        k.Y,
+			X:        k.X,
+			Comments: comment,
+		})
+	case *ecdsa.PrivateKey:
+		// The curve identifier is derived from the key size, e.g. "nistp256".
+		nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
+		req = ssh.Marshal(ecdsaKeyMsg{
+			Type:     "ecdsa-sha2-" + nistID,
+			Curve:    nistID,
+			KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
+			D:        k.D,
+			Comments: comment,
+		})
+	default:
+		return fmt.Errorf("ssh: unsupported key type %T", s)
+	}
+	// The agent answers an add request with a bare success or failure,
+	// which is exactly what simpleCall checks.
+	return c.simpleCall(req)
+}
+
+// rsaCertMsg, dsaCertMsg and ecdsaCertMsg are the certificate variants of
+// the add-identity request (message 17). Type is the certificate type name
+// and CertBytes is the marshaled certificate, which replaces the public-key
+// fields; only the private values follow it. Field order is the wire order;
+// do not reorder.
+type rsaCertMsg struct {
+ Type string `sshtype:"17"`
+ CertBytes []byte
+ D *big.Int
+ Iqmp *big.Int // IQMP = Inverse Q Mod P
+ P *big.Int
+ Q *big.Int
+ Comments string
+}
+
+type dsaCertMsg struct {
+ Type string `sshtype:"17"`
+ CertBytes []byte
+ X *big.Int
+ Comments string
+}
+
+type ecdsaCertMsg struct {
+ Type string `sshtype:"17"`
+ CertBytes []byte
+ D *big.Int
+ Comments string
+}
+
+// Add adds the private key s to the agent. When cert is non-nil, the
+// certificate rather than the bare public key is registered as the key's
+// public half.
+func (c *client) Add(s interface{}, cert *ssh.Certificate, comment string) error {
+	if cert != nil {
+		return c.insertCert(s, cert, comment)
+	}
+	return c.insertKey(s, comment)
+}
+
+// insertCert adds the private key s to the agent together with cert, using
+// the certificate variant of the add-identity message. Supported key types
+// match insertKey. The request is not sent if cert does not certify the
+// public half of s.
+func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string) error {
+	var req []byte
+	switch k := s.(type) {
+	case *rsa.PrivateKey:
+		if len(k.Primes) != 2 {
+			return fmt.Errorf("ssh: unsupported RSA key with %d primes", len(k.Primes))
+		}
+		// Precompute fills in Precomputed.Qinv, which the message carries.
+		k.Precompute()
+		req = ssh.Marshal(rsaCertMsg{
+			Type:      cert.Type(),
+			CertBytes: cert.Marshal(),
+			D:         k.D,
+			Iqmp:      k.Precomputed.Qinv,
+			P:         k.Primes[0],
+			Q:         k.Primes[1],
+			Comments:  comment,
+		})
+	case *dsa.PrivateKey:
+		req = ssh.Marshal(dsaCertMsg{
+			Type:      cert.Type(),
+			CertBytes: cert.Marshal(),
+			X:         k.X,
+			Comments:  comment,
+		})
+	case *ecdsa.PrivateKey:
+		req = ssh.Marshal(ecdsaCertMsg{
+			Type:      cert.Type(),
+			CertBytes: cert.Marshal(),
+			D:         k.D,
+			Comments:  comment,
+		})
+	default:
+		return fmt.Errorf("ssh: unsupported key type %T", s)
+	}
+
+	// Refuse to register a certificate for a different key than s.
+	signer, err := ssh.NewSignerFromKey(s)
+	if err != nil {
+		return err
+	}
+	if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
+		return errors.New("ssh: signer and cert have different public key")
+	}
+
+	return c.simpleCall(req)
+}
+
+// Signers returns one ssh.Signer per identity currently listed by the
+// agent, suitable as a callback for client authentication. Each signer
+// delegates signing to the agent.
+func (c *client) Signers() ([]ssh.Signer, error) {
+	keys, err := c.List()
+	if err != nil {
+		return nil, err
+	}
+
+	var signers []ssh.Signer
+	for i := range keys {
+		signers = append(signers, &agentKeyringSigner{agent: c, pub: keys[i]})
+	}
+	return signers, nil
+}
+
+// agentKeyringSigner is an ssh.Signer whose private key is held by the
+// agent. It knows only the public key and forwards Sign calls to agent.
+type agentKeyringSigner struct {
+	agent *client
+	pub   ssh.PublicKey
+}
+
+// PublicKey returns the public key the agent reported for this identity.
+func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { return s.pub }
+
+// Sign asks the agent to sign data. rand is ignored because the agent uses
+// its own entropy source.
+func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
+	return s.agent.Sign(s.pub, data)
+}
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go
new file mode 100644
index 0000000..aa99e27
--- /dev/null
+++ b/ssh/agent/client_test.go
@@ -0,0 +1,270 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+ "bytes"
+ "crypto/rand"
+ "errors"
+ "net"
+ "os"
+ "os/exec"
+ "strconv"
+ "testing"
+
+ "code.google.com/p/go.crypto/ssh"
+)
+
+// startAgent launches a private ssh-agent and connects to it. It returns a
+// client for the agent, the agent's socket path, and a cleanup function that
+// kills the agent and closes the connection. The test is skipped when no
+// ssh-agent binary is on $PATH.
+func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) {
+	bin, err := exec.LookPath("ssh-agent")
+	if err != nil {
+		t.Skip("could not find ssh-agent")
+	}
+
+	out, err := exec.Command(bin, "-s").Output()
+	if err != nil {
+		t.Fatalf("cmd.Output: %v", err)
+	}
+
+	// With -s the agent prints Bourne shell commands such as:
+	//
+	//	SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
+	//	SSH_AGENT_PID=15542; export SSH_AGENT_PID;
+	//	echo Agent pid 15542;
+	//
+	// Splitting on ';' puts the two assignments at indices 0 and 2.
+	fields := bytes.Split(out, []byte(";"))
+
+	kv := bytes.SplitN(fields[0], []byte("="), 2)
+	kv[0] = bytes.TrimLeft(kv[0], "\n")
+	if string(kv[0]) != "SSH_AUTH_SOCK" {
+		t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
+	}
+	socket = string(kv[1])
+
+	kv = bytes.SplitN(fields[2], []byte("="), 2)
+	kv[0] = bytes.TrimLeft(kv[0], "\n")
+	if string(kv[0]) != "SSH_AGENT_PID" {
+		t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
+	}
+	pidStr := kv[1]
+	pid, err := strconv.Atoi(string(pidStr))
+	if err != nil {
+		t.Fatalf("Atoi(%q): %v", pidStr, err)
+	}
+
+	conn, err := net.Dial("unix", socket)
+	if err != nil {
+		t.Fatalf("net.Dial: %v", err)
+	}
+
+	cleanup = func() {
+		if proc, _ := os.FindProcess(pid); proc != nil {
+			proc.Kill()
+		}
+		conn.Close()
+	}
+	return NewClient(conn), socket, cleanup
+}
+
+// testAgent runs testAgentInterface against a freshly started ssh-agent.
+func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate) {
+	client, _, cleanup := startAgent(t)
+	defer cleanup()
+	testAgentInterface(t, client, key, cert)
+}
+
+// testAgentInterface checks that an initially empty agent accepts key (with
+// cert, if non-nil), lists it back with its comment, and produces a
+// signature that verifies against the advertised public key.
+func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate) {
+	signer, err := ssh.NewSignerFromKey(key)
+	if err != nil {
+		t.Fatalf("NewSignerFromKey: %v", err)
+	}
+
+	// A new agent must not hold any identities.
+	keys, err := agent.List()
+	if err != nil {
+		t.Fatalf("RequestIdentities: %v", err)
+	}
+	if len(keys) > 0 {
+		t.Fatalf("got %d keys, want 0: %v", len(keys), keys)
+	}
+
+	// When a certificate is supplied, the agent should advertise it rather
+	// than the bare public key.
+	var pubKey ssh.PublicKey = signer.PublicKey()
+	if cert != nil {
+		pubKey = cert
+	}
+	if err := agent.Add(key, cert, "comment"); err != nil {
+		t.Fatalf("insert: %v", err)
+	}
+
+	keys, err = agent.List()
+	switch {
+	case err != nil:
+		t.Fatalf("List: %v", err)
+	case len(keys) != 1:
+		t.Fatalf("got %v, want 1 key", keys)
+	case keys[0].Comment != "comment":
+		t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment")
+	case !bytes.Equal(keys[0].Blob, pubKey.Marshal()):
+		t.Fatalf("key mismatch")
+	}
+
+	data := []byte("hello")
+	sig, err := agent.Sign(pubKey, data)
+	if err != nil {
+		t.Fatalf("Sign: %v", err)
+	}
+	if err := pubKey.Verify(data, sig); err != nil {
+		t.Fatalf("key signature Verify: %v", err)
+	}
+}
+
+// TestAgent exercises a real ssh-agent with each supported private key type.
+func TestAgent(t *testing.T) {
+	for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
+		t.Log(keyType)
+		testAgent(t, testPrivateKeys[keyType], nil)
+	}
+}
+
+// TestCert adds an RSA key together with an ECDSA-signed user certificate
+// to a real ssh-agent and checks that the certificate is what gets listed
+// and used for signing.
+func TestCert(t *testing.T) {
+	cert := &ssh.Certificate{
+		Key:         testPublicKeys["rsa"],
+		ValidBefore: ssh.CertTimeInfinity,
+		CertType:    ssh.UserCert,
+	}
+	// An unsigned certificate would make the later failures confusing, so
+	// fail here if signing does not work.
+	if err := cert.SignCert(rand.Reader, testSigners["ecdsa"]); err != nil {
+		t.Fatalf("SignCert: %v", err)
+	}
+
+	testAgent(t, testPrivateKeys["rsa"], cert)
+}
+
+// netPipe returns both ends of a loopback TCP connection. Unlike net.Pipe
+// it uses a real net.Conn, which is buffered, so both sides may write first
+// without deadlocking.
+func netPipe() (net.Conn, net.Conn, error) {
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return nil, nil, err
+	}
+	defer l.Close()
+
+	dialed, err := net.Dial("tcp", l.Addr().String())
+	if err != nil {
+		return nil, nil, err
+	}
+	accepted, err := l.Accept()
+	if err != nil {
+		dialed.Close()
+		return nil, nil, err
+	}
+	return dialed, accepted, nil
+}
+
+// TestAuth checks that signers obtained from a real ssh-agent can complete
+// public-key authentication against an in-process SSH server.
+func TestAuth(t *testing.T) {
+	a, b, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+
+	defer a.Close()
+	defer b.Close()
+
+	agent, _, cleanup := startAgent(t)
+	defer cleanup()
+
+	if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil {
+		t.Errorf("Add: %v", err)
+	}
+
+	serverConf := ssh.ServerConfig{}
+	serverConf.AddHostKey(testSigners["rsa"])
+	serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+		if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+			return nil, nil
+		}
+
+		return nil, errors.New("pubkey rejected")
+	}
+
+	// t.Fatalf may only be called from the test goroutine, so the server
+	// goroutine reports its result on a channel. The buffer lets it finish
+	// even if the test fails before reading the result.
+	serverDone := make(chan error, 1)
+	go func() {
+		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
+		if err == nil {
+			conn.Close()
+		}
+		serverDone <- err
+	}()
+
+	conf := ssh.ClientConfig{}
+	conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
+	conn, _, _, err := ssh.NewClientConn(b, "", &conf)
+	if err != nil {
+		t.Fatalf("NewClientConn: %v", err)
+	}
+	conn.Close()
+
+	if err := <-serverDone; err != nil {
+		t.Fatalf("Server: %v", err)
+	}
+}
+
+// TestLockClient runs the lock/unlock checks against a real ssh-agent.
+func TestLockClient(t *testing.T) {
+	agent, _, cleanup := startAgent(t)
+	defer cleanup()
+	testLockAgent(agent, t)
+}
+
+// testLockAgent adds an RSA and a DSA key and locks the agent. While locked,
+// it checks that List is empty and that Sign, Remove and RemoveAll fail.
+// After unlocking, it checks that removing the RSA key leaves only the DSA
+// key.
+func testLockAgent(agent Agent, t *testing.T) {
+	if err := agent.Add(testPrivateKeys["rsa"], nil, "comment 1"); err != nil {
+		t.Errorf("Add: %v", err)
+	}
+	if err := agent.Add(testPrivateKeys["dsa"], nil, "comment dsa"); err != nil {
+		t.Errorf("Add: %v", err)
+	}
+	if keys, err := agent.List(); err != nil {
+		t.Errorf("List: %v", err)
+	} else if len(keys) != 2 {
+		t.Errorf("Want 2 keys, got %v", keys)
+	}
+
+	passphrase := []byte("secret")
+	if err := agent.Lock(passphrase); err != nil {
+		t.Errorf("Lock: %v", err)
+	}
+
+	if keys, err := agent.List(); err != nil {
+		t.Errorf("List: %v", err)
+	} else if len(keys) != 0 {
+		t.Errorf("Want 0 keys, got %v", keys)
+	}
+
+	// A nil signer here would otherwise surface as a confusing panic below.
+	signer, err := ssh.NewSignerFromKey(testPrivateKeys["rsa"])
+	if err != nil {
+		t.Fatalf("NewSignerFromKey: %v", err)
+	}
+	if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil {
+		t.Fatalf("Sign did not fail")
+	}
+
+	if err := agent.Remove(signer.PublicKey()); err == nil {
+		t.Fatalf("Remove did not fail")
+	}
+
+	if err := agent.RemoveAll(); err == nil {
+		t.Fatalf("RemoveAll did not fail")
+	}
+
+	if err := agent.Unlock(nil); err == nil {
+		t.Errorf("Unlock with wrong passphrase succeeded")
+	}
+	if err := agent.Unlock(passphrase); err != nil {
+		t.Errorf("Unlock: %v", err)
+	}
+
+	if err := agent.Remove(signer.PublicKey()); err != nil {
+		t.Fatalf("Remove: %v", err)
+	}
+
+	// Only the DSA key should remain after removing the RSA key.
+	if keys, err := agent.List(); err != nil {
+		t.Errorf("List: %v", err)
+	} else if len(keys) != 1 {
+		t.Errorf("Want 1 keys, got %v", keys)
+	} else if keys[0].Comment != "comment dsa" {
+		t.Errorf("remaining key comment: got %q, want %q", keys[0].Comment, "comment dsa")
+	}
+}
diff --git a/ssh/agent/forward.go b/ssh/agent/forward.go
new file mode 100644
index 0000000..dd45c3e
--- /dev/null
+++ b/ssh/agent/forward.go
@@ -0,0 +1,103 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+ "errors"
+ "io"
+ "net"
+ "sync"
+
+ "code.google.com/p/go.crypto/ssh"
+)
+
+// RequestAgentForwarding asks the server to forward agent connections for
+// the session ("auth-agent-req@openssh.com"). Forwarded connections must
+// then be routed with ForwardToAgent or ForwardToRemote.
+func RequestAgentForwarding(session *ssh.Session) error {
+	ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
+	switch {
+	case err != nil:
+		return err
+	case !ok:
+		return errors.New("forwarding request denied")
+	default:
+		return nil
+	}
+}
+
+// ForwardToAgent serves every agent channel the server opens on client from
+// keyring, handling each channel in its own goroutine. It fails if a
+// handler for the agent channel type is already registered on client.
+func ForwardToAgent(client *ssh.Client, keyring Agent) error {
+	channels := client.HandleChannelOpen(channelType)
+	if channels == nil {
+		return errors.New("agent: already have handler for " + channelType)
+	}
+
+	go func() {
+		for newCh := range channels {
+			ch, reqs, err := newCh.Accept()
+			if err != nil {
+				continue
+			}
+			go ssh.DiscardRequests(reqs)
+			go func(ch ssh.Channel) {
+				defer ch.Close()
+				ServeAgent(keyring, ch)
+			}(ch)
+		}
+	}()
+	return nil
+}
+
+// channelType is the SSH channel type used for forwarded agent connections.
+const channelType = "auth-agent@openssh.com"
+
+// ForwardToRemote routes authentication requests to the ssh-agent
+// process serving on the given unix socket. Each forwarded channel gets its
+// own connection to addr.
+func ForwardToRemote(client *ssh.Client, addr string) error {
+	// Probe the socket before registering the handler. If the dial failed
+	// after HandleChannelOpen, the handler would stay registered with
+	// nobody draining it, and forwarded agent channels would go unanswered.
+	conn, err := net.Dial("unix", addr)
+	if err != nil {
+		return err
+	}
+	conn.Close()
+
+	channels := client.HandleChannelOpen(channelType)
+	if channels == nil {
+		return errors.New("agent: already have handler for " + channelType)
+	}
+
+	go func() {
+		for ch := range channels {
+			channel, reqs, err := ch.Accept()
+			if err != nil {
+				continue
+			}
+			go ssh.DiscardRequests(reqs)
+			go forwardUnixSocket(channel, addr)
+		}
+	}()
+	return nil
+}
+
+// forwardUnixSocket copies data in both directions between channel and a
+// new connection to the unix socket at addr. Each direction half-closes its
+// destination when its source reaches EOF. Both ends are closed once both
+// copies finish, or immediately if the socket cannot be reached.
+func forwardUnixSocket(channel ssh.Channel, addr string) {
+	conn, err := net.Dial("unix", addr)
+	if err != nil {
+		// Close the accepted channel so the peer sees the failure instead
+		// of waiting on a channel nobody services.
+		channel.Close()
+		return
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		io.Copy(conn, channel)
+		conn.(*net.UnixConn).CloseWrite()
+	}()
+	go func() {
+		defer wg.Done()
+		io.Copy(channel, conn)
+		channel.CloseWrite()
+	}()
+
+	wg.Wait()
+	conn.Close()
+	channel.Close()
+}
diff --git a/ssh/agent/keyring.go b/ssh/agent/keyring.go
new file mode 100644
index 0000000..ecfa66f
--- /dev/null
+++ b/ssh/agent/keyring.go
@@ -0,0 +1,183 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/subtle"
+ "errors"
+ "fmt"
+ "sync"
+
+ "code.google.com/p/go.crypto/ssh"
+)
+
+// privKey is one identity held by a keyring: the signer for the private key
+// (possibly wrapping a certificate) and the comment it was added with.
+type privKey struct {
+	signer  ssh.Signer
+	comment string
+}
+
+// keyring is an in-memory Agent. All fields are guarded by mu.
+type keyring struct {
+	mu   sync.Mutex
+	keys []privKey
+
+	// locked is set by Lock; passphrase is what Unlock must be given.
+	locked     bool
+	passphrase []byte
+}
+
+// errLocked is returned by operations refused while the keyring is locked.
+var errLocked = errors.New("agent: locked")
+
+// NewKeyring returns an Agent that holds keys in memory. It is safe
+// for concurrent use by multiple goroutines.
+func NewKeyring() Agent {
+	return new(keyring)
+}
+
+// RemoveAll forgets every identity in the keyring. It fails with errLocked
+// while the keyring is locked.
+func (r *keyring) RemoveAll() error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if r.locked {
+		return errLocked
+	}
+	r.keys = nil
+	return nil
+}
+
+// Remove deletes every identity whose public key equals key. It fails with
+// errLocked while the keyring is locked, and with an error if no identity
+// matched. The relative order of the remaining keys is not preserved.
+func (r *keyring) Remove(key ssh.PublicKey) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return errLocked
+	}
+
+	want := key.Marshal()
+	found := false
+	for i := 0; i < len(r.keys); {
+		if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
+			found = true
+			// Swap-delete: move the last key into slot i and shrink the
+			// slice by one. Slot i now holds a different key, so it is
+			// examined again without advancing i.
+			last := len(r.keys) - 1
+			r.keys[i] = r.keys[last]
+			r.keys[last] = privKey{} // drop the reference held by the backing array
+			r.keys = r.keys[:last]
+			continue
+		}
+		i++
+	}
+
+	if !found {
+		return errors.New("agent: key not found")
+	}
+	return nil
+}
+
+// Lock locks the keyring with passphrase. While locked, Sign, Remove,
+// RemoveAll, Add and Signers fail with errLocked, and List returns an empty
+// list. Locking an already locked keyring fails with errLocked.
+func (r *keyring) Lock(passphrase []byte) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return errLocked
+	}
+
+	r.locked = true
+	// Keep a private copy so the caller cannot change the unlock
+	// passphrase by mutating its slice later.
+	r.passphrase = append([]byte(nil), passphrase...)
+	return nil
+}
+
+// Unlock reverses a previous Lock, provided passphrase equals the one given
+// to Lock. It fails if the keyring is not locked.
+func (r *keyring) Unlock(passphrase []byte) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if !r.locked {
+		return errors.New("agent: not locked")
+	}
+	// The explicit length check matters: older versions of
+	// subtle.ConstantTimeCompare assume equal-length inputs.
+	matches := len(passphrase) == len(r.passphrase) &&
+		subtle.ConstantTimeCompare(passphrase, r.passphrase) == 1
+	if !matches {
+		return fmt.Errorf("agent: incorrect passphrase")
+	}
+
+	r.locked, r.passphrase = false, nil
+	return nil
+}
+
+// List returns the public half and comment of every identity in the
+// keyring. A locked keyring reports no identities and no error.
+func (r *keyring) List() ([]*Key, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		// [PROTOCOL.agent] section 2.7: a locked agent lists nothing.
+		return nil, nil
+	}
+
+	var ids []*Key
+	for i := range r.keys {
+		pub := r.keys[i].signer.PublicKey()
+		ids = append(ids, &Key{Format: pub.Type(), Blob: pub.Marshal(), Comment: r.keys[i].comment})
+	}
+	return ids, nil
+}
+
+// Add stores priv in the keyring under comment. When cert is non-nil, the
+// identity is exposed with the certificate as its public key.
+func (r *keyring) Add(priv interface{}, cert *ssh.Certificate, comment string) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return errLocked
+	}
+
+	signer, err := ssh.NewSignerFromKey(priv)
+	if err == nil && cert != nil {
+		signer, err = ssh.NewCertSigner(cert, signer)
+	}
+	if err != nil {
+		return err
+	}
+
+	r.keys = append(r.keys, privKey{signer, comment})
+	return nil
+}
+
+// Sign signs data with the first identity whose public key equals key,
+// drawing randomness from crypto/rand.
+func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return nil, errLocked
+	}
+
+	blob := key.Marshal()
+	for i := range r.keys {
+		s := r.keys[i].signer
+		if !bytes.Equal(s.PublicKey().Marshal(), blob) {
+			continue
+		}
+		return s.Sign(rand.Reader, data)
+	}
+	return nil, errors.New("not found")
+}
+
+// Signers returns the signers for all identities in the keyring, in
+// insertion order. It fails with errLocked while the keyring is locked.
+func (r *keyring) Signers() ([]ssh.Signer, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	if r.locked {
+		return nil, errLocked
+	}
+
+	// Length 0, capacity len(r.keys): appending to a slice created with
+	// length len(r.keys) would leave that many nil signers in front.
+	s := make([]ssh.Signer, 0, len(r.keys))
+	for _, k := range r.keys {
+		s = append(s, k.signer)
+	}
+	return s, nil
+}
diff --git a/ssh/agent/server.go b/ssh/agent/server.go
new file mode 100644
index 0000000..2d55dc9
--- /dev/null
+++ b/ssh/agent/server.go
@@ -0,0 +1,209 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+ "crypto/rsa"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "log"
+ "math/big"
+
+ "code.google.com/p/go.crypto/ssh"
+)
+
+// server wraps an Agent and uses it to implement the agent side of
+// the SSH-agent wire protocol.
+type server struct {
+	agent Agent
+}
+
+// processRequestBytes handles one request payload and returns the reply
+// payload. A failed request is answered with agentFailure. A request that
+// succeeds with no reply message is answered with agentSuccess.
+func (s *server) processRequestBytes(reqData []byte) []byte {
+	// An empty frame has no message type. Reject it here; otherwise the
+	// reqData[0] / data[0] accesses would panic on peer-controlled input.
+	if len(reqData) == 0 {
+		return []byte{agentFailure}
+	}
+
+	rep, err := s.processRequest(reqData)
+	if err != nil {
+		if err != errLocked {
+			// TODO(hanwen): provide better logging interface?
+			log.Printf("agent %d: %v", reqData[0], err)
+		}
+		return []byte{agentFailure}
+	}
+
+	if rep == nil {
+		return []byte{agentSuccess}
+	}
+
+	return ssh.Marshal(rep)
+}
+
+func marshalKey(k *Key) []byte {
+ var record struct {
+ Blob []byte
+ Comment string
+ }
+ record.Blob = k.Marshal()
+ record.Comment = k.Comment
+
+ return ssh.Marshal(&record)
+}
+
+// agentV1IdentityMsg is the reply (type 2) to a protocol 1 identity request;
+// processRequest always sends it with Numkeys == 0.
+type agentV1IdentityMsg struct {
+ Numkeys uint32 `sshtype:"2"`
+}
+
+// agentRemoveIdentityMsg (type 18) asks the agent to remove the key whose
+// serialized public key is KeyBlob.
+type agentRemoveIdentityMsg struct {
+ KeyBlob []byte `sshtype:"18"`
+}
+
+// agentLockMsg (type 22) carries the passphrase that locks the agent.
+type agentLockMsg struct {
+ Passphrase []byte `sshtype:"22"`
+}
+
+// agentUnlockMsg (type 23) carries the passphrase that unlocks the agent.
+// It has the same layout as agentLockMsg but a different sshtype tag, so
+// each request must be decoded with its own struct.
+type agentUnlockMsg struct {
+ Passphrase []byte `sshtype:"23"`
+}
+
+// processRequest decodes one request (data[0] is the message type) and
+// dispatches it to s.agent. It returns the reply message to send. A nil
+// reply with a nil error means the request succeeded and should be
+// answered with a bare success message. data must be non-empty.
+func (s *server) processRequest(data []byte) (interface{}, error) {
+	switch data[0] {
+	case agentRequestV1Identities:
+		// Protocol 1 keys are not supported; report an empty list.
+		return &agentV1IdentityMsg{0}, nil
+
+	case agentRemoveIdentity:
+		var req agentRemoveIdentityMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+
+		var wk wireKey
+		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
+			return nil, err
+		}
+
+		return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
+
+	case agentRemoveAllIdentities:
+		return nil, s.agent.RemoveAll()
+
+	case agentLock:
+		var req agentLockMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+
+		return nil, s.agent.Lock(req.Passphrase)
+
+	case agentUnlock:
+		// Decode with agentUnlockMsg, whose sshtype tag (23) matches this
+		// request. agentLockMsg is tagged 22, and ssh.Unmarshal rejects a
+		// message whose type byte differs from the struct's tag.
+		var req agentUnlockMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+		return nil, s.agent.Unlock(req.Passphrase)
+
+	case agentSignRequest:
+		var req signRequestAgentMsg
+		if err := ssh.Unmarshal(data, &req); err != nil {
+			return nil, err
+		}
+
+		var wk wireKey
+		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
+			return nil, err
+		}
+
+		k := &Key{
+			Format: wk.Format,
+			Blob:   req.KeyBlob,
+		}
+
+		sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
+		if err != nil {
+			return nil, err
+		}
+		return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
+
+	case agentRequestIdentities:
+		keys, err := s.agent.List()
+		if err != nil {
+			return nil, err
+		}
+
+		rep := identitiesAnswerAgentMsg{
+			NumKeys: uint32(len(keys)),
+		}
+		for _, k := range keys {
+			rep.Keys = append(rep.Keys, marshalKey(k)...)
+		}
+		return rep, nil
+
+	case agentAddIdentity:
+		return nil, s.insertIdentity(data)
+	}
+
+	return nil, fmt.Errorf("unknown opcode %d", data[0])
+}
+
+// insertIdentity decodes an add-identity request (type 17) and adds the key
+// it carries to s.agent. Only RSA keys are currently understood; other key
+// types are rejected with a "not implemented" error.
+func (s *server) insertIdentity(req []byte) error {
+	var record struct {
+		Type string `sshtype:"17"`
+		Rest []byte `ssh:"rest"`
+	}
+	if err := ssh.Unmarshal(req, &record); err != nil {
+		return err
+	}
+
+	switch record.Type {
+	case ssh.KeyAlgoRSA:
+		var k rsaKeyMsg
+		if err := ssh.Unmarshal(req, &k); err != nil {
+			return err
+		}
+
+		// E is converted to int below; refuse values that would be
+		// silently truncated rather than validating a different key.
+		if k.E.BitLen() > 31 {
+			return fmt.Errorf("agent: RSA public exponent too large")
+		}
+		priv := rsa.PrivateKey{
+			PublicKey: rsa.PublicKey{
+				E: int(k.E.Int64()),
+				N: k.N,
+			},
+			D:      k.D,
+			Primes: []*big.Int{k.P, k.Q},
+		}
+		// The key material comes from the peer. Validate rejects degenerate
+		// values (such as primes <= 1) that would make Precompute panic,
+		// and inconsistent N/E/D/prime combinations.
+		if err := priv.Validate(); err != nil {
+			return err
+		}
+		priv.Precompute()
+
+		return s.agent.Add(&priv, nil, k.Comments)
+	}
+	return fmt.Errorf("not implemented: %s", record.Type)
+}
+
+// ServeAgent serves the agent protocol on the given connection, answering
+// each length-prefixed request with a length-prefixed reply. It never
+// returns nil: it stops at the first I/O error, or when a request or reply
+// exceeds maxAgentResponseBytes.
+func ServeAgent(agent Agent, c io.ReadWriter) error {
+	s := &server{agent: agent}
+
+	var hdr [4]byte
+	for {
+		if _, err := io.ReadFull(c, hdr[:]); err != nil {
+			return err
+		}
+		size := binary.BigEndian.Uint32(hdr[:])
+		// The reply size limit doubles as a sanity cap on requests.
+		if size > maxAgentResponseBytes {
+			return fmt.Errorf("agent: request too large: %d", size)
+		}
+
+		req := make([]byte, size)
+		if _, err := io.ReadFull(c, req); err != nil {
+			return err
+		}
+
+		reply := s.processRequestBytes(req)
+		if len(reply) > maxAgentResponseBytes {
+			return fmt.Errorf("agent: reply too large: %d bytes", len(reply))
+		}
+
+		binary.BigEndian.PutUint32(hdr[:], uint32(len(reply)))
+		if _, err := c.Write(hdr[:]); err != nil {
+			return err
+		}
+		if _, err := c.Write(reply); err != nil {
+			return err
+		}
+	}
+}
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go
new file mode 100644
index 0000000..ad2996b
--- /dev/null
+++ b/ssh/agent/server_test.go
@@ -0,0 +1,88 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package agent
+
+import (
+	"testing"
+
+	"code.google.com/p/go.crypto/ssh"
+)
+
+// TestServer runs testAgentInterface against a Keyring served by
+// ServeAgent over the connection pair returned by netPipe.
+func TestServer(t *testing.T) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+	defer c1.Close()
+	defer c2.Close()
+	client := NewClient(c1)
+
+	go ServeAgent(NewKeyring(), c2)
+
+	testAgentInterface(t, client, testPrivateKeys["rsa"], nil)
+}
+
+func TestLockServer(t *testing.T) {
+	testLockAgent(NewKeyring(), t)
+}
+
+// TestSetupForwardAgent checks that, after ForwardToRemote, the server
+// side of an SSH connection can reach the agent listening at socket.
+func TestSetupForwardAgent(t *testing.T) {
+	a, b, err := netPipe()
+	if err != nil {
+		t.Fatalf("netPipe: %v", err)
+	}
+
+	defer a.Close()
+	defer b.Close()
+
+	_, socket, cleanup := startAgent(t)
+	defer cleanup()
+
+	serverConf := ssh.ServerConfig{
+		NoClientAuth: true,
+	}
+	serverConf.AddHostKey(testSigners["rsa"])
+	// t.Fatalf may only be called from the test goroutine, so the
+	// server goroutine reports a failed handshake by sending nil.
+	incoming := make(chan *ssh.ServerConn, 1)
+	go func() {
+		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
+		if err != nil {
+			t.Errorf("Server: %v", err)
+			incoming <- nil
+			return
+		}
+		incoming <- conn
+	}()
+
+	conf := ssh.ClientConfig{}
+	conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
+	if err != nil {
+		t.Fatalf("NewClientConn: %v", err)
+	}
+	defer conn.Close()
+	client := ssh.NewClient(conn, chans, reqs)
+
+	if err := ForwardToRemote(client, socket); err != nil {
+		t.Fatalf("SetupForwardAgent: %v", err)
+	}
+
+	server := <-incoming
+	if server == nil {
+		t.FailNow()
+	}
+	ch, reqs, err := server.OpenChannel(channelType, nil)
+	if err != nil {
+		t.Fatalf("OpenChannel(%q): %v", channelType, err)
+	}
+	go ssh.DiscardRequests(reqs)
+
+	agentClient := NewClient(ch)
+	testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil)
+}
diff --git a/ssh/agent/testdata_test.go b/ssh/agent/testdata_test.go
new file mode 100644
index 0000000..6bb75a9
--- /dev/null
+++ b/ssh/agent/testdata_test.go
@@ -0,0 +1,69 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
+// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
+// instances.
+
+package agent
+
+import (
+	"crypto/rand"
+	"fmt"
+
+	"code.google.com/p/go.crypto/ssh"
+	"code.google.com/p/go.crypto/ssh/testdata"
+)
+
+var (
+	testPrivateKeys map[string]interface{}
+	testSigners     map[string]ssh.Signer
+	testPublicKeys  map[string]ssh.PublicKey
+)
+
+// init fills testPrivateKeys, testSigners and testPublicKeys from
+// testdata.PEMBytes, then adds a "cert" private key and signer: the ecdsa
+// key certified by the rsa key.
+func init() {
+	var err error
+
+	n := len(testdata.PEMBytes)
+	testPrivateKeys = make(map[string]interface{}, n)
+	testSigners = make(map[string]ssh.Signer, n)
+	testPublicKeys = make(map[string]ssh.PublicKey, n)
+	for t, k := range testdata.PEMBytes {
+		testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
+		if err != nil {
+			panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
+		}
+		testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
+		if err != nil {
+			panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
+		}
+		testPublicKeys[t] = testSigners[t].PublicKey()
+	}
+
+	// Create a cert and sign it for use in tests.
+	testCert := &ssh.Certificate{
+		Nonce:           []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
+		ValidAfter:      0,                              // unix epoch
+		ValidBefore:     ssh.CertTimeInfinity,           // The end of currently representable time.
+		Reserved:        []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		Key:             testPublicKeys["ecdsa"],
+		SignatureKey:    testPublicKeys["rsa"],
+		Permissions: ssh.Permissions{
+			CriticalOptions: map[string]string{},
+			Extensions:      map[string]string{},
+		},
+	}
+	if err = testCert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
+		panic(fmt.Sprintf("Unable to sign test certificate: %v", err))
+	}
+	testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
+	testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
+	if err != nil {
+		panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
+	}
+}
diff --git a/ssh/benchmark_test.go b/ssh/benchmark_test.go
new file mode 100644
index 0000000..d9f7eb9
--- /dev/null
+++ b/ssh/benchmark_test.go
@@ -0,0 +1,142 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"errors"
+	"io"
+	"net"
+	"testing"
+)
+
+// server bundles a ServerConn with its stream of incoming channel requests.
+type server struct {
+	*ServerConn
+	chans <-chan NewChannel
+}
+
+// newServer runs the server side of the handshake on c and discards any
+// global requests the peer sends afterwards.
+func newServer(c net.Conn, conf *ServerConfig) (*server, error) {
+	sconn, chans, reqs, err := NewServerConn(c, conf)
+	if err != nil {
+		return nil, err
+	}
+	go DiscardRequests(reqs)
+	return &server{sconn, chans}, nil
+}
+
+// Accept returns the next incoming channel request, or io.EOF once the
+// connection stops delivering them.
+func (s *server) Accept() (NewChannel, error) {
+	n, ok := <-s.chans
+	if !ok {
+		return nil, io.EOF
+	}
+	return n, nil
+}
+
+// sshPipe connects a client and a server over netPipe, with the server
+// accepting any client (NoClientAuth).
+func sshPipe() (Conn, *server, error) {
+	c1, c2, err := netPipe()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	clientConf := ClientConfig{
+		User: "user",
+	}
+	serverConf := ServerConfig{
+		NoClientAuth: true,
+	}
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	done := make(chan *server, 1)
+	go func() {
+		// Send exactly once; newServer yields a nil *server on failure,
+		// which the receiver below reports as a failed handshake.
+		server, _ := newServer(c2, &serverConf)
+		done <- server
+	}()
+
+	client, _, reqs, err := NewClientConn(c1, "", &clientConf)
+	if err != nil {
+		// Closing both ends also unblocks the server-side handshake.
+		c1.Close()
+		c2.Close()
+		return nil, nil, err
+	}
+
+	server := <-done
+	if server == nil {
+		client.Close()
+		return nil, nil, errors.New("server handshake failed.")
+	}
+	go DiscardRequests(reqs)
+
+	return client, server, nil
+}
+
+// BenchmarkEndToEnd measures the throughput of b.N 1 MiB writes over a
+// single channel between an in-process client and server.
+func BenchmarkEndToEnd(b *testing.B) {
+	b.StopTimer()
+
+	client, server, err := sshPipe()
+	if err != nil {
+		b.Fatalf("sshPipe: %v", err)
+	}
+
+	defer client.Close()
+	defer server.Close()
+
+	size := (1 << 20)
+	input := make([]byte, size)
+	output := make([]byte, size)
+	b.SetBytes(int64(size))
+	// b.Fatalf may only be called from the benchmark goroutine, so the
+	// reader below reports problems with b.Errorf and closes done on exit.
+	done := make(chan struct{})
+
+	go func() {
+		defer close(done)
+		newCh, err := server.Accept()
+		if err != nil {
+			b.Errorf("Client: %v", err)
+			return
+		}
+		ch, incoming, err := newCh.Accept()
+		if err != nil {
+			b.Errorf("Accept: %v", err)
+			return
+		}
+		go DiscardRequests(incoming)
+		for i := 0; i < b.N; i++ {
+			if _, err := io.ReadFull(ch, output); err != nil {
+				b.Errorf("ReadFull: %v", err)
+				break
+			}
+		}
+		ch.Close()
+	}()
+
+	ch, in, err := client.OpenChannel("speed", nil)
+	if err != nil {
+		b.Fatalf("OpenChannel: %v", err)
+	}
+	go DiscardRequests(in)
+
+	b.ResetTimer()
+	b.StartTimer()
+	for i := 0; i < b.N; i++ {
+		if _, err := ch.Write(input); err != nil {
+			b.Fatalf("WriteFull: %v", err)
+		}
+	}
+	ch.Close()
+	b.StopTimer()
+
+	<-done
+}
diff --git a/ssh/buffer.go b/ssh/buffer.go
index 601dad3..6931b51 100644
--- a/ssh/buffer.go
+++ b/ssh/buffer.go
@@ -43,29 +43,29 @@
// buf must not be modified after the call to write.
func (b *buffer) write(buf []byte) {
b.Cond.L.Lock()
- defer b.Cond.L.Unlock()
e := &element{buf: buf}
b.tail.next = e
b.tail = e
b.Cond.Signal()
+ b.Cond.L.Unlock()
}
// eof closes the buffer. Reads from the buffer once all
// the data has been consumed will receive os.EOF.
func (b *buffer) eof() error {
b.Cond.L.Lock()
- defer b.Cond.L.Unlock()
b.closed = true
b.Cond.Signal()
+ b.Cond.L.Unlock()
return nil
}
-// Read reads data from the internal buffer in buf.
-// Reads will block if no data is available, or until
-// the buffer is closed.
+// Read reads data from the internal buffer in buf. Reads will block
+// if no data is available, or until the buffer is closed.
func (b *buffer) Read(buf []byte) (n int, err error) {
b.Cond.L.Lock()
defer b.Cond.L.Unlock()
+
for len(buf) > 0 {
// if there is data in b.head, copy it
if len(b.head.buf) > 0 {
@@ -79,10 +79,12 @@
b.head = b.head.next
continue
}
+
// if at least one byte has been copied, return
if n > 0 {
break
}
+
// if nothing was read, and there is nothing outstanding
// check to see if the buffer is closed.
if b.closed {
diff --git a/ssh/buffer_test.go b/ssh/buffer_test.go
index 135c4ae..d5781cb 100644
--- a/ssh/buffer_test.go
+++ b/ssh/buffer_test.go
@@ -9,33 +9,33 @@
"testing"
)
-var BYTES = []byte("abcdefghijklmnopqrstuvwxyz")
+var alphabet = []byte("abcdefghijklmnopqrstuvwxyz")
func TestBufferReadwrite(t *testing.T) {
b := newBuffer()
- b.write(BYTES[:10])
+ b.write(alphabet[:10])
r, _ := b.Read(make([]byte, 10))
if r != 10 {
t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
}
b = newBuffer()
- b.write(BYTES[:5])
+ b.write(alphabet[:5])
r, _ = b.Read(make([]byte, 10))
if r != 5 {
t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
}
b = newBuffer()
- b.write(BYTES[:10])
+ b.write(alphabet[:10])
r, _ = b.Read(make([]byte, 5))
if r != 5 {
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
}
b = newBuffer()
- b.write(BYTES[:5])
- b.write(BYTES[5:15])
+ b.write(alphabet[:5])
+ b.write(alphabet[5:15])
r, _ = b.Read(make([]byte, 10))
r2, _ := b.Read(make([]byte, 10))
if r != 10 || r2 != 5 || 15 != r+r2 {
@@ -45,14 +45,14 @@
func TestBufferClose(t *testing.T) {
b := newBuffer()
- b.write(BYTES[:10])
+ b.write(alphabet[:10])
b.eof()
_, err := b.Read(make([]byte, 5))
if err != nil {
t.Fatal("expected read of 5 to not return EOF")
}
b = newBuffer()
- b.write(BYTES[:10])
+ b.write(alphabet[:10])
b.eof()
r, err := b.Read(make([]byte, 5))
r2, err2 := b.Read(make([]byte, 10))
@@ -61,7 +61,7 @@
}
b = newBuffer()
- b.write(BYTES[:10])
+ b.write(alphabet[:10])
b.eof()
r, err = b.Read(make([]byte, 5))
r2, err2 = b.Read(make([]byte, 10))
diff --git a/ssh/certs.go b/ssh/certs.go
index d958f31..9962ff0 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -5,6 +5,12 @@
package ssh
import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sort"
"time"
)
@@ -18,67 +24,348 @@
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
)
-// Certificate types are used to specify whether a certificate is for identification
-// of a user or a host. Current identities are defined in [PROTOCOL.certkeys].
+// Certificate types distinguish between host and user
+// certificates. The values can be set in the CertType field of
+// Certificate.
const (
UserCert = 1
HostCert = 2
)
-type signature struct {
+// Signature represents a cryptographic signature.
+type Signature struct {
Format string
Blob []byte
}
-type tuple struct {
- Name string
- Data string
-}
+// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
+// a certificate does not expire.
+const CertTimeInfinity = 1<<64 - 1
-const (
- maxUint64 = 1<<64 - 1
- maxInt64 = 1<<63 - 1
-)
-
-// CertTime represents an unsigned 64-bit time value in seconds starting from
-// UNIX epoch. We use CertTime instead of time.Time in order to properly handle
-// the "infinite" time value ^0, which would become negative when expressed as
-// an int64.
-type CertTime uint64
-
-func (ct CertTime) Time() time.Time {
- if ct > maxInt64 {
- return time.Unix(maxInt64, 0)
- }
- return time.Unix(int64(ct), 0)
-}
-
-func (ct CertTime) IsInfinite() bool {
- return ct == maxUint64
-}
-
-// An OpenSSHCertV01 represents an OpenSSH certificate as defined in
+// An Certificate represents an OpenSSH certificate as defined in
// [PROTOCOL.certkeys]?rev=1.8.
-type OpenSSHCertV01 struct {
- Nonce []byte
- Key PublicKey
- Serial uint64
- Type uint32
- KeyId string
- ValidPrincipals []string
- ValidAfter, ValidBefore CertTime
- CriticalOptions []tuple
- Extensions []tuple
- Reserved []byte
- SignatureKey PublicKey
- Signature *signature
+type Certificate struct {
+ Nonce []byte
+ Key PublicKey
+ Serial uint64
+ CertType uint32
+ KeyId string
+ ValidPrincipals []string
+ ValidAfter uint64
+ ValidBefore uint64
+ Permissions
+ Reserved []byte
+ SignatureKey PublicKey
+ Signature *Signature
}
-// validateOpenSSHCertV01Signature uses the cert's SignatureKey to verify that
-// the cert's Signature.Blob is the result of signing the cert bytes starting
-// from the algorithm string and going up to and including the SignatureKey.
-func validateOpenSSHCertV01Signature(cert *OpenSSHCertV01) bool {
- return cert.SignatureKey.Verify(cert.BytesForSigning(), cert.Signature.Blob)
+// genericCertData holds the key-independent part of the certificate data.
+// Overall, certificates contain an nonce, public key fields and
+// key-independent fields.
+type genericCertData struct {
+ Serial uint64
+ CertType uint32
+ KeyId string
+ ValidPrincipals []byte
+ ValidAfter uint64
+ ValidBefore uint64
+ CriticalOptions []byte
+ Extensions []byte
+ Reserved []byte
+ SignatureKey []byte
+ Signature []byte
+}
+
+// marshalStringList concatenates the Marshal encoding of each name.
+func marshalStringList(namelist []string) []byte {
+	var out []byte
+	for i := range namelist {
+		out = append(out, Marshal(&struct{ N string }{namelist[i]})...)
+	}
+	return out
+}
+
+// marshalTuples encodes tups as name/value string pairs sorted by name;
+// parseTuples rejects options that are not in lexical order.
+func marshalTuples(tups map[string]string) []byte {
+	names := make([]string, 0, len(tups))
+	for name := range tups {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+	var out []byte
+	for _, name := range names {
+		out = append(out, Marshal(&struct{ K, V string }{name, tups[name]})...)
+	}
+	return out
+}
+
+// parseTuples decodes consecutive name/value string pairs into a map. It
+// returns errShortRead on truncated input and rejects names that are not in
+// strictly increasing order, which also rules out duplicates.
+func parseTuples(in []byte) (map[string]string, error) {
+	tups := map[string]string{}
+	prev, havePrev := "", false
+
+	for rest := in; len(rest) > 0; {
+		nameBytes, afterName, ok := parseString(rest)
+		if !ok {
+			return nil, errShortRead
+		}
+		value, afterValue, ok := parseString(afterName)
+		if !ok {
+			return nil, errShortRead
+		}
+		name := string(nameBytes)
+		// according to [PROTOCOL.certkeys], the names must be in
+		// lexical order.
+		if havePrev && name <= prev {
+			return nil, errors.New("ssh: certificate options are not in lexical order")
+		}
+		prev, havePrev = name, true
+		tups[name] = string(value)
+		rest = afterValue
+	}
+	return tups, nil
+}
+
+func parseCert(in []byte, privAlgo string) (*Certificate, error) {
+ nonce, rest, ok := parseString(in)
+ if !ok {
+ return nil, errShortRead
+ }
+
+ key, rest, err := parsePubKey(rest, privAlgo)
+ if err != nil {
+ return nil, err
+ }
+
+ var g genericCertData
+ if err := Unmarshal(rest, &g); err != nil {
+ return nil, err
+ }
+
+ c := &Certificate{
+ Nonce: nonce,
+ Key: key,
+ Serial: g.Serial,
+ CertType: g.CertType,
+ KeyId: g.KeyId,
+ ValidAfter: g.ValidAfter,
+ ValidBefore: g.ValidBefore,
+ }
+
+ for principals := g.ValidPrincipals; len(principals) > 0; {
+ principal, rest, ok := parseString(principals)
+ if !ok {
+ return nil, errShortRead
+ }
+ c.ValidPrincipals = append(c.ValidPrincipals, string(principal))
+ principals = rest
+ }
+
+ c.CriticalOptions, err = parseTuples(g.CriticalOptions)
+ if err != nil {
+ return nil, err
+ }
+ c.Extensions, err = parseTuples(g.Extensions)
+ if err != nil {
+ return nil, err
+ }
+ c.Reserved = g.Reserved
+ k, err := ParsePublicKey(g.SignatureKey)
+ if err != nil {
+ return nil, err
+ }
+
+ c.SignatureKey = k
+ c.Signature, rest, ok = parseSignatureBody(g.Signature)
+ if !ok || len(rest) > 0 {
+ return nil, errors.New("ssh: signature parse error")
+ }
+
+ return c, nil
+}
+
+type openSSHCertSigner struct {
+ pub *Certificate
+ signer Signer
+}
+
+// NewCertSigner returns a Signer that signs with the given Certificate, whose
+// private key is held by signer: Sign delegates to signer, while PublicKey
+// reports cert. It returns an error if the public key in cert doesn't match
+// the key used by signer.
+func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
+	if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
+		return nil, errors.New("ssh: signer and cert have different public key")
+	}
+	return &openSSHCertSigner{cert, signer}, nil
+}
+
+func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
+ return s.signer.Sign(rand, data)
+}
+
+func (s *openSSHCertSigner) PublicKey() PublicKey {
+ return s.pub
+}
+
+const sourceAddressCriticalOption = "source-address"
+
+// CertChecker does the work of verifying a certificate. Its methods
+// can be plugged into ClientConfig.HostKeyCallback and
+// ServerConfig.PublicKeyCallback. For the CertChecker to work,
+// minimally, the IsAuthority callback should be set.
+type CertChecker struct {
+ // SupportedCriticalOptions lists the CriticalOptions that the
+ // server application layer understands. These are only used
+ // for user certificates.
+ SupportedCriticalOptions []string
+
+ // IsAuthority should return true if the key is recognized as
+ // an authority. This allows for certificates to be signed by other
+ // certificates.
+ IsAuthority func(auth PublicKey) bool
+
+ // Clock is used for verifying time stamps. If nil, time.Now
+ // is used.
+ Clock func() time.Time
+
+ // UserKeyFallback is called when CertChecker.Authenticate encounters a
+ // public key that is not a certificate. It must implement validation
+ // of user keys or else, if nil, all such keys are rejected.
+ UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+
+ // HostKeyFallback is called when CertChecker.CheckHostKey encounters a
+ // public key that is not a certificate. It must implement host key
+ // validation or else, if nil, all such keys are rejected.
+ HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error
+
+ // IsRevoked is called for each certificate so that revocation checking
+ // can be implemented. It should return true if the given certificate
+ // is revoked and false otherwise. If nil, no certificates are
+ // considered to have been revoked.
+ IsRevoked func(cert *Certificate) bool
+}
+
+// CheckHostKey checks a host key certificate, with addr as the principal;
+// non-certificate keys are passed to HostKeyFallback when it is set. This
+// method can be plugged into ClientConfig.HostKeyCallback.
+func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error {
+	cert, isCert := key.(*Certificate)
+	switch {
+	case !isCert && c.HostKeyFallback != nil:
+		return c.HostKeyFallback(addr, remote, key)
+	case !isCert:
+		return errors.New("ssh: non-certificate host key")
+	case cert.CertType != HostCert:
+		return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
+	}
+
+	return c.CheckCert(addr, cert)
+}
+
+// Authenticate checks a user certificate, using the connecting user name as
+// the principal, and on success returns the certificate's Permissions. Plain
+// keys go to UserKeyFallback when it is set. Authenticate can be used as a
+// value for ServerConfig.PublicKeyCallback.
+func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) {
+	cert, isCert := pubKey.(*Certificate)
+	switch {
+	case !isCert && c.UserKeyFallback != nil:
+		return c.UserKeyFallback(conn, pubKey)
+	case !isCert:
+		return nil, errors.New("ssh: normal key pairs not accepted")
+	case cert.CertType != UserCert:
+		return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
+	}
+
+	if err := c.CheckCert(conn.User(), cert); err != nil {
+		return nil, err
+	}
+
+	return &cert.Permissions, nil
+}
+
+// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
+// the signature of the certificate. IsAuthority must be set; if it is nil,
+// every certificate is rejected.
+func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
+	if c.IsRevoked != nil && c.IsRevoked(cert) {
+		return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
+	}
+
+	for opt := range cert.CriticalOptions {
+		// sourceAddressCriticalOption will be enforced by
+		// serverAuthenticate
+		if opt == sourceAddressCriticalOption {
+			continue
+		}
+		found := false
+		for _, supp := range c.SupportedCriticalOptions {
+			if supp == opt {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt)
+		}
+	}
+
+	// An empty ValidPrincipals list makes the cert valid for any principal.
+	if len(cert.ValidPrincipals) > 0 {
+		found := false
+		for _, p := range cert.ValidPrincipals {
+			if p == principal {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals)
+		}
+	}
+
+	// A nil IsAuthority would otherwise panic here.
+	if c.IsAuthority == nil || !c.IsAuthority(cert.SignatureKey) {
+		return fmt.Errorf("ssh: certificate signed by unrecognized authority")
+	}
+	clock := c.Clock
+	if clock == nil {
+		clock = time.Now
+	}
+	unixNow := clock().Unix()
+	// Timestamps above MaxInt64 become negative as int64 and are rejected,
+	// except that ValidBefore == CertTimeInfinity never expires.
+	if after := int64(cert.ValidAfter); after < 0 || unixNow < after {
+		return fmt.Errorf("ssh: cert is not yet valid")
+	}
+	if before := int64(cert.ValidBefore); cert.ValidBefore != CertTimeInfinity && (unixNow >= before || before < 0) {
+		return fmt.Errorf("ssh: cert has expired")
+	}
+	if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
+		return fmt.Errorf("ssh: certificate signature does not verify")
+	}
+	return nil
+}
+
+// SignCert replaces c.Nonce with 32 bytes read from rand, sets c.SignatureKey
+// to the authority's public key and stores a Signature, by authority, in c.
+func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
+	c.Nonce = make([]byte, 32)
+	if _, err := io.ReadFull(rand, c.Nonce); err != nil {
+		return err
+	}
+	c.SignatureKey = authority.PublicKey()
+
+	sig, err := authority.Sign(rand, c.bytesForSigning())
+	if err != nil {
+		return err
+	}
+	c.Signature = sig
+	return nil
}
var certAlgoNames = map[string]string{
@@ -100,260 +387,69 @@
panic("unknown cert algorithm")
}
-func (cert *OpenSSHCertV01) marshal(includeAlgo, includeSig bool) []byte {
- algoName := cert.PublicKeyAlgo()
- pubKey := cert.Key.Marshal()
- sigKey := MarshalPublicKey(cert.SignatureKey)
-
- var length int
- if includeAlgo {
- length += stringLength(len(algoName))
- }
- length += stringLength(len(cert.Nonce))
- length += len(pubKey)
- length += 8 // Length of Serial
- length += 4 // Length of Type
- length += stringLength(len(cert.KeyId))
- length += lengthPrefixedNameListLength(cert.ValidPrincipals)
- length += 8 // Length of ValidAfter
- length += 8 // Length of ValidBefore
- length += tupleListLength(cert.CriticalOptions)
- length += tupleListLength(cert.Extensions)
- length += stringLength(len(cert.Reserved))
- length += stringLength(len(sigKey))
- if includeSig {
- length += signatureLength(cert.Signature)
- }
-
- ret := make([]byte, length)
- r := ret
- if includeAlgo {
- r = marshalString(r, []byte(algoName))
- }
- r = marshalString(r, cert.Nonce)
- copy(r, pubKey)
- r = r[len(pubKey):]
- r = marshalUint64(r, cert.Serial)
- r = marshalUint32(r, cert.Type)
- r = marshalString(r, []byte(cert.KeyId))
- r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals)
- r = marshalUint64(r, uint64(cert.ValidAfter))
- r = marshalUint64(r, uint64(cert.ValidBefore))
- r = marshalTupleList(r, cert.CriticalOptions)
- r = marshalTupleList(r, cert.Extensions)
- r = marshalString(r, cert.Reserved)
- r = marshalString(r, sigKey)
- if includeSig {
- r = marshalSignature(r, cert.Signature)
- }
- if len(r) > 0 {
- panic("ssh: internal error, marshaling certificate did not fill the entire buffer")
- }
- return ret
+func (cert *Certificate) bytesForSigning() []byte {
+ c2 := *cert
+ c2.Signature = nil
+ out := c2.Marshal()
+ // Drop trailing signature length.
+ return out[:len(out)-4]
}
-func (cert *OpenSSHCertV01) BytesForSigning() []byte {
- return cert.marshal(true, false)
+// Marshal serializes c into OpenSSH's wire format. It is part of the
+// PublicKey interface.
+func (c *Certificate) Marshal() []byte {
+ generic := genericCertData{
+ Serial: c.Serial,
+ CertType: c.CertType,
+ KeyId: c.KeyId,
+ ValidPrincipals: marshalStringList(c.ValidPrincipals),
+ ValidAfter: uint64(c.ValidAfter),
+ ValidBefore: uint64(c.ValidBefore),
+ CriticalOptions: marshalTuples(c.CriticalOptions),
+ Extensions: marshalTuples(c.Extensions),
+ Reserved: c.Reserved,
+ SignatureKey: c.SignatureKey.Marshal(),
+ }
+ if c.Signature != nil {
+ generic.Signature = Marshal(c.Signature)
+ }
+ genericBytes := Marshal(&generic)
+ keyBytes := c.Key.Marshal()
+ _, keyBytes, _ = parseString(keyBytes)
+ prefix := Marshal(&struct {
+ Name string
+ Nonce []byte
+ Key []byte `ssh:"rest"`
+ }{c.Type(), c.Nonce, keyBytes})
+
+ result := make([]byte, 0, len(prefix)+len(genericBytes))
+ result = append(result, prefix...)
+ result = append(result, genericBytes...)
+ return result
}
-func (cert *OpenSSHCertV01) Marshal() []byte {
- return cert.marshal(false, true)
-}
-
-func (c *OpenSSHCertV01) PublicKeyAlgo() string {
- algo, ok := certAlgoNames[c.Key.PublicKeyAlgo()]
+// Type returns the key name. It is part of the PublicKey interface.
+func (c *Certificate) Type() string {
+ algo, ok := certAlgoNames[c.Key.Type()]
if !ok {
panic("unknown cert key type")
}
return algo
}
-func (c *OpenSSHCertV01) PrivateKeyAlgo() string {
- return c.Key.PrivateKeyAlgo()
-}
-
-func (c *OpenSSHCertV01) Verify(data []byte, sig []byte) bool {
+// Verify verifies a signature against the certificate's public
+// key. It is part of the PublicKey interface.
+func (c *Certificate) Verify(data []byte, sig *Signature) error {
return c.Key.Verify(data, sig)
}
-func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []byte, ok bool) {
- cert := new(OpenSSHCertV01)
-
- if cert.Nonce, in, ok = parseString(in); !ok {
- return
- }
-
- privAlgo := certToPrivAlgo(algo)
- cert.Key, in, ok = parsePubKey(in, privAlgo)
+func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) {
+ format, in, ok := parseString(in)
if !ok {
return
}
- // We test PublicKeyAlgo to make sure we don't use some weird sub-cert.
- if cert.Key.PublicKeyAlgo() != privAlgo {
- ok = false
- return
- }
-
- if cert.Serial, in, ok = parseUint64(in); !ok {
- return
- }
-
- if cert.Type, in, ok = parseUint32(in); !ok {
- return
- }
-
- keyId, in, ok := parseString(in)
- if !ok {
- return
- }
- cert.KeyId = string(keyId)
-
- if cert.ValidPrincipals, in, ok = parseLengthPrefixedNameList(in); !ok {
- return
- }
-
- va, in, ok := parseUint64(in)
- if !ok {
- return
- }
- cert.ValidAfter = CertTime(va)
-
- vb, in, ok := parseUint64(in)
- if !ok {
- return
- }
- cert.ValidBefore = CertTime(vb)
-
- if cert.CriticalOptions, in, ok = parseTupleList(in); !ok {
- return
- }
-
- if cert.Extensions, in, ok = parseTupleList(in); !ok {
- return
- }
-
- if cert.Reserved, in, ok = parseString(in); !ok {
- return
- }
-
- sigKey, in, ok := parseString(in)
- if !ok {
- return
- }
- if cert.SignatureKey, _, ok = ParsePublicKey(sigKey); !ok {
- return
- }
-
- if cert.Signature, in, ok = parseSignature(in); !ok {
- return
- }
-
- ok = true
- return cert, in, ok
-}
-
-func lengthPrefixedNameListLength(namelist []string) int {
- length := 4 // length prefix for list
- for _, name := range namelist {
- length += 4 // length prefix for name
- length += len(name)
- }
- return length
-}
-
-func marshalLengthPrefixedNameList(to []byte, namelist []string) []byte {
- length := uint32(lengthPrefixedNameListLength(namelist) - 4)
- to = marshalUint32(to, length)
- for _, name := range namelist {
- to = marshalString(to, []byte(name))
- }
- return to
-}
-
-func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool) {
- list, rest, ok := parseString(in)
- if !ok {
- return
- }
-
- for len(list) > 0 {
- var next []byte
- if next, list, ok = parseString(list); !ok {
- return nil, nil, false
- }
- out = append(out, string(next))
- }
- ok = true
- return
-}
-
-func tupleListLength(tupleList []tuple) int {
- length := 4 // length prefix for list
- for _, t := range tupleList {
- length += 4 // length prefix for t.Name
- length += len(t.Name)
- length += 4 // length prefix for t.Data
- length += len(t.Data)
- }
- return length
-}
-
-func marshalTupleList(to []byte, tuplelist []tuple) []byte {
- length := uint32(tupleListLength(tuplelist) - 4)
- to = marshalUint32(to, length)
- for _, t := range tuplelist {
- to = marshalString(to, []byte(t.Name))
- to = marshalString(to, []byte(t.Data))
- }
- return to
-}
-
-func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) {
- list, rest, ok := parseString(in)
- if !ok {
- return
- }
-
- for len(list) > 0 {
- var name, data []byte
- var ok bool
- name, list, ok = parseString(list)
- if !ok {
- return nil, nil, false
- }
- data, list, ok = parseString(list)
- if !ok {
- return nil, nil, false
- }
- out = append(out, tuple{string(name), string(data)})
- }
- ok = true
- return
-}
-
-func signatureLength(sig *signature) int {
- length := 4 // length prefix for signature
- length += stringLength(len(sig.Format))
- length += stringLength(len(sig.Blob))
- return length
-}
-
-func marshalSignature(to []byte, sig *signature) []byte {
- length := uint32(signatureLength(sig) - 4)
- to = marshalUint32(to, length)
- to = marshalString(to, []byte(sig.Format))
- to = marshalString(to, sig.Blob)
- return to
-}
-
-func parseSignatureBody(in []byte) (out *signature, rest []byte, ok bool) {
- var format []byte
- if format, in, ok = parseString(in); !ok {
- return
- }
-
- out = &signature{
+ out = &Signature{
Format: string(format),
}
@@ -364,14 +460,14 @@
return out, in, ok
}
-func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
- var sigBytes []byte
- if sigBytes, rest, ok = parseString(in); !ok {
+func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) {
+ sigBytes, rest, ok := parseString(in)
+ if !ok {
return
}
- out, sigBytes, ok = parseSignatureBody(sigBytes)
- if !ok || len(sigBytes) > 0 {
+ out, trailing, ok := parseSignatureBody(sigBytes)
+ if !ok || len(trailing) > 0 {
return nil, nil, false
}
return
diff --git a/ssh/certs_test.go b/ssh/certs_test.go
index 3cec28e..7d1b00f 100644
--- a/ssh/certs_test.go
+++ b/ssh/certs_test.go
@@ -6,7 +6,9 @@
import (
"bytes"
+ "crypto/rand"
"testing"
+ "time"
)
// Cert generated by ssh-keygen 6.0p1 Debian-4.
@@ -16,16 +18,16 @@
func TestParseCert(t *testing.T) {
authKeyBytes := []byte(exampleSSHCert)
- key, _, _, rest, ok := ParseAuthorizedKey(authKeyBytes)
- if !ok {
- t.Fatalf("could not parse certificate")
+ key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
+ if err != nil {
+ t.Fatalf("ParseAuthorizedKey: %v", err)
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
- if _, ok = key.(*OpenSSHCertV01); !ok {
- t.Fatalf("got %#v, want *OpenSSHCertV01", key)
+ if _, ok := key.(*Certificate); !ok {
+ t.Fatalf("got %#v, want *Certificate", key)
}
marshaled := MarshalAuthorizedKey(key)
@@ -37,19 +39,118 @@
}
}
-func TestVerifyCert(t *testing.T) {
- key, _, _, _, _ := ParseAuthorizedKey([]byte(exampleSSHCert))
- validCert := key.(*OpenSSHCertV01)
- if ok := validateOpenSSHCertV01Signature(validCert); !ok {
- t.Error("Unable to validate certificate!")
+func TestValidateCert(t *testing.T) {
+ key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
+ if err != nil {
+ t.Fatalf("ParseAuthorizedKey: %v", err)
+ }
+ validCert, ok := key.(*Certificate)
+ if !ok {
+ t.Fatalf("got %v (%T), want *Certificate", key, key)
+ }
+ checker := CertChecker{}
+ checker.IsAuthority = func(k PublicKey) bool {
+ return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
}
- invalidCert := &OpenSSHCertV01{
- Key: rsaKey.PublicKey(),
- SignatureKey: ecdsaKey.PublicKey(),
- Signature: &signature{},
+ if err := checker.CheckCert("user", validCert); err != nil {
+ t.Errorf("Unable to validate certificate: %v", err)
}
- if ok := validateOpenSSHCertV01Signature(invalidCert); ok {
- t.Error("Invalid cert signature passed validation!")
+ invalidCert := &Certificate{
+ Key: testPublicKeys["rsa"],
+ SignatureKey: testPublicKeys["ecdsa"],
+ ValidBefore: CertTimeInfinity,
+ Signature: &Signature{},
+ }
+ if err := checker.CheckCert("user", invalidCert); err == nil {
+ t.Error("Invalid cert signature passed validation")
+ }
+}
+
+// TestValidateCertTime checks that CheckCert accepts clock values in
+// [ValidAfter, ValidBefore) and rejects those outside that range.
+func TestValidateCertTime(t *testing.T) {
+	cert := Certificate{
+		ValidPrincipals: []string{"user"},
+		Key:             testPublicKeys["rsa"],
+		ValidAfter:      50,
+		ValidBefore:     100,
+	}
+	if err := cert.SignCert(rand.Reader, testSigners["ecdsa"]); err != nil {
+		t.Fatalf("SignCert: %v", err)
+	}
+	for ts, ok := range map[int64]bool{
+		25:  false,
+		50:  true,
+		99:  true,
+		100: false,
+		125: false,
+	} {
+		checker := CertChecker{
+			Clock: func() time.Time { return time.Unix(ts, 0) },
+		}
+		checker.IsAuthority = func(k PublicKey) bool {
+			return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
+		}
+		if v := checker.CheckCert("user", &cert); (v == nil) != ok {
+			t.Errorf("CheckCert at time %d: got %v, want success %v", ts, v, ok)
+		}
+	}
+}
+
+// TODO(hanwen): tests for
+//
+// host keys:
+// * fallbacks
+
+// TestHostKeyCert checks that CertChecker.CheckHostKey accepts a host
+// certificate for a listed principal and rejects it for any other host.
+func TestHostKeyCert(t *testing.T) {
+	cert := &Certificate{
+		ValidPrincipals: []string{"hostname", "hostname.domain"},
+		Key:             testPublicKeys["rsa"],
+		ValidBefore:     CertTimeInfinity,
+		CertType:        HostCert,
+	}
+	if err := cert.SignCert(rand.Reader, testSigners["ecdsa"]); err != nil {
+		t.Fatalf("SignCert: %v", err)
+	}
+	checker := &CertChecker{
+		IsAuthority: func(p PublicKey) bool {
+			return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
+		},
+	}
+	certSigner, err := NewCertSigner(cert, testSigners["rsa"])
+	if err != nil {
+		t.Fatalf("NewCertSigner: %v", err)
+	}
+
+	for _, name := range []string{"hostname", "otherhost"} {
+		c1, c2, err := netPipe()
+		if err != nil {
+			t.Fatalf("netPipe: %v", err)
+		}
+		defer c1.Close()
+		defer c2.Close()
+
+		go func() {
+			conf := ServerConfig{
+				NoClientAuth: true,
+			}
+			conf.AddHostKey(certSigner)
+			// The handshake is expected to fail for hosts the client
+			// rejects, and t.Fatalf must not run off the test goroutine,
+			// so the outcome is judged on the client side below.
+			NewServerConn(c1, &conf)
+		}()
+
+		config := &ClientConfig{
+			User:            "user",
+			HostKeyCallback: checker.CheckHostKey,
+		}
+		_, _, _, err = NewClientConn(c2, name, config)
+		if succeed := name == "hostname"; (err == nil) != succeed {
+			t.Fatalf("NewClientConn(%q): %v", name, err)
+		}
}
}
diff --git a/ssh/channel.go b/ssh/channel.go
index c5413c9..8e777bb 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -5,71 +5,100 @@
package ssh
import (
+ "encoding/binary"
"errors"
"fmt"
"io"
+ "log"
"sync"
- "sync/atomic"
)
-// extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
-// section 5.2.
-type extendedDataTypeCode uint32
-
 const (
-	// extendedDataStderr is the extended data type that is used for stderr.
-	extendedDataStderr extendedDataTypeCode = 1
-
-	// minPacketLength defines the smallest valid packet
+	// minPacketLength is the lower bound used when validating the
+	// MaxPacketSize a peer advertises in a channel open confirmation.
 	minPacketLength = 9
-
-	// channelMaxPacketSize defines the maximum packet size advertised in open messages
-	channelMaxPacketSize = 1 << 15 // RFC 4253 6.1, minimum 32 KiB
-
-	// channelWindowSize defines the window size advertised in open messages
-	channelWindowSize = 64 * channelMaxPacketSize // Like OpenSSH
+	// channelMaxPacket is the maximum payload size, in bytes, that we
+	// advertise (as MaxPacketSize) for data the peer sends to us on a
+	// channel. As per RFC 4253, section 6.1, 32k is also the minimum.
+	channelMaxPacket = 1 << 15
+	// channelWindowSize is the initial flow-control window for incoming
+	// channel data. We follow OpenSSH here.
+	channelWindowSize = 64 * channelMaxPacket
 )
-// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
-// SSH connection. Channel.Read can return a ChannelRequest as an error.
-type Channel interface {
-	// Accept accepts the channel creation request.
-	Accept() error
-	// Reject rejects the channel creation request. After calling this, no
-	// other methods on the Channel may be called. If they are then the
-	// peer is likely to signal a protocol error and drop the connection.
+// NewChannel represents an incoming request to a channel. It must either be
+// accepted for use by calling Accept, or rejected by calling Reject. Only
+// one of the two may be called, and only once.
+type NewChannel interface {
+	// Accept accepts the channel creation request. It returns the Channel
+	// and a Go channel containing SSH requests. The Go channel must be
+	// serviced, otherwise the Channel will hang.
+	Accept() (Channel, <-chan *Request, error)
+
+	// Reject rejects the channel creation request, sending reason and
+	// message to the peer. After calling this, no other methods on the
+	// Channel may be called.
 	Reject(reason RejectionReason, message string) error
-	// Read may return a ChannelRequest as an error.
-	Read(data []byte) (int, error)
-	Write(data []byte) (int, error)
-	Close() error
-
-	// Stderr returns an io.Writer that writes to this channel with the
-	// extended data type set to stderr.
-	Stderr() io.Writer
-
-	// AckRequest either sends an ack or nack to the channel request.
-	AckRequest(ok bool) error
-
 	// ChannelType returns the type of the channel, as supplied by the
 	// client.
 	ChannelType() string
+
 	// ExtraData returns the arbitrary payload for this channel, as supplied
 	// by the client. This data is specific to the channel type.
 	ExtraData() []byte
 }
-// ChannelRequest represents a request sent on a channel, outside of the normal
-// stream of bytes. It may result from calling Read on a Channel.
-type ChannelRequest struct {
-	Request   string
-	WantReply bool
-	Payload   []byte
+// A Channel is an ordered, reliable, flow-controlled, duplex stream
+// that is multiplexed over an SSH connection.
+type Channel interface {
+	// Read reads up to len(data) bytes from the channel's normal
+	// (non-extended) data stream.
+	Read(data []byte) (int, error)
+
+	// Write writes len(data) bytes to the channel's normal data stream.
+	Write(data []byte) (int, error)
+
+	// Close signals end of channel use. No data may be sent after this
+	// call.
+	Close() error
+
+	// CloseWrite signals the end of sending in-band data. Requests may
+	// still be sent, and the other side may still send data.
+	CloseWrite() error
+
+	// SendRequest sends a channel request. If wantReply is true,
+	// it will wait for a reply and return the result as a
+	// boolean, otherwise the return value will be false. Channel
+	// requests are out-of-band messages so they may be sent even
+	// if the data stream is closed or blocked by flow control.
+	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
+
+	// Stderr returns an io.ReadWriter that reads from and writes to the
+	// channel's stderr extended data stream.
+	Stderr() io.ReadWriter
 }
-func (c ChannelRequest) Error() string {
-	return "ssh: channel request received"
+// Request is an out-of-band message carried alongside the normal data
+// stream. A Request either belongs to one SSH channel (ch is set) or is
+// global to the connection (ch is nil and mux is used instead).
+type Request struct {
+	Type      string
+	WantReply bool
+	Payload   []byte
+
+	ch  *channel
+	mux *mux
+}
+
+// Reply answers the request. It must be called for every request whose
+// WantReply is set and does nothing otherwise. For channel-specific
+// requests the payload is not transmitted.
+func (r *Request) Reply(ok bool, payload []byte) error {
+	switch {
+	case !r.WantReply:
+		// The peer is not waiting for an answer.
+		return nil
+	case r.ch != nil:
+		// Channel-level replies have no payload field.
+		return r.ch.ackRequest(ok)
+	default:
+		return r.mux.ackRequest(ok, payload)
+	}
 }
// RejectionReason is an enumeration used when rejecting channel creation
@@ -98,464 +127,6 @@
return fmt.Sprintf("unknown reason %d", int(r))
}
-type channel struct {
- packetConn // the underlying transport
- localId, remoteId uint32
- remoteWin window
- maxPacket uint32
- isClosed uint32 // atomic bool, non zero if true
-}
-
-func (c *channel) sendWindowAdj(n int) error {
- msg := windowAdjustMsg{
- PeersId: c.remoteId,
- AdditionalBytes: uint32(n),
- }
- return c.writePacket(marshal(msgChannelWindowAdjust, msg))
-}
-
-// sendEOF sends EOF to the remote side. RFC 4254 Section 5.3
-func (c *channel) sendEOF() error {
- return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
- PeersId: c.remoteId,
- }))
-}
-
-// sendClose informs the remote side of our intent to close the channel.
-func (c *channel) sendClose() error {
- return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{
- PeersId: c.remoteId,
- }))
-}
-
-func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
- reject := channelOpenFailureMsg{
- PeersId: c.remoteId,
- Reason: reason,
- Message: message,
- Language: "en",
- }
- return c.writePacket(marshal(msgChannelOpenFailure, reject))
-}
-
-func (c *channel) writePacket(b []byte) error {
- if c.closed() {
- return io.EOF
- }
- if uint32(len(b)) > c.maxPacket {
- return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
- }
- return c.packetConn.writePacket(b)
-}
-
-func (c *channel) closed() bool {
- return atomic.LoadUint32(&c.isClosed) > 0
-}
-
-func (c *channel) setClosed() bool {
- return atomic.CompareAndSwapUint32(&c.isClosed, 0, 1)
-}
-
-type serverChan struct {
- channel
- // immutable once created
- chanType string
- extraData []byte
-
- serverConn *ServerConn
- myWindow uint32
- theyClosed bool // indicates the close msg has been received from the remote side
- theySentEOF bool
- isDead uint32
- err error
-
- pendingRequests []ChannelRequest
- pendingData []byte
- head, length int
-
- // This lock is inferior to serverConn.lock
- cond *sync.Cond
-}
-
-func (c *serverChan) Accept() error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
-
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
-
- confirm := channelOpenConfirmMsg{
- PeersId: c.remoteId,
- MyId: c.localId,
- MyWindow: c.myWindow,
- MaxPacketSize: c.maxPacket,
- }
- return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
-}
-
-func (c *serverChan) Reject(reason RejectionReason, message string) error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
-
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
-
- return c.sendChannelOpenFailure(reason, message)
-}
-
-func (c *serverChan) handlePacket(packet interface{}) {
- c.cond.L.Lock()
- defer c.cond.L.Unlock()
-
- switch packet := packet.(type) {
- case *channelRequestMsg:
- req := ChannelRequest{
- Request: packet.Request,
- WantReply: packet.WantReply,
- Payload: packet.RequestSpecificData,
- }
-
- c.pendingRequests = append(c.pendingRequests, req)
- c.cond.Signal()
- case *channelCloseMsg:
- c.theyClosed = true
- c.cond.Signal()
- case *channelEOFMsg:
- c.theySentEOF = true
- c.cond.Signal()
- case *windowAdjustMsg:
- if !c.remoteWin.add(packet.AdditionalBytes) {
- panic("illegal window update")
- }
- default:
- panic("unknown packet type")
- }
-}
-
-func (c *serverChan) handleData(data []byte) {
- c.cond.L.Lock()
- defer c.cond.L.Unlock()
-
- // The other side should never send us more than our window.
- if len(data)+c.length > len(c.pendingData) {
- // TODO(agl): we should tear down the channel with a protocol
- // error.
- return
- }
-
- c.myWindow -= uint32(len(data))
- for i := 0; i < 2; i++ {
- tail := c.head + c.length
- if tail >= len(c.pendingData) {
- tail -= len(c.pendingData)
- }
- n := copy(c.pendingData[tail:], data)
- data = data[n:]
- c.length += n
- }
-
- c.cond.Signal()
-}
-
-func (c *serverChan) Stderr() io.Writer {
- return extendedDataChannel{c: c, t: extendedDataStderr}
-}
-
-// extendedDataChannel is an io.Writer that writes any data to c as extended
-// data of the given type.
-type extendedDataChannel struct {
- t extendedDataTypeCode
- c *serverChan
-}
-
-func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
- const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
- c := edc.c
- for len(data) > 0 {
- space := min(c.maxPacket-headerLength, len(data))
- if space, err = c.getWindowSpace(space); err != nil {
- return 0, err
- }
- todo := data
- if uint32(len(todo)) > space {
- todo = todo[:space]
- }
-
- packet := make([]byte, headerLength+len(todo))
- packet[0] = msgChannelExtendedData
- marshalUint32(packet[1:], c.remoteId)
- marshalUint32(packet[5:], uint32(edc.t))
- marshalUint32(packet[9:], uint32(len(todo)))
- copy(packet[13:], todo)
-
- if err = c.writePacket(packet); err != nil {
- return
- }
-
- n += len(todo)
- data = data[len(todo):]
- }
-
- return
-}
-
-func (c *serverChan) Read(data []byte) (n int, err error) {
- n, err, windowAdjustment := c.read(data)
-
- if windowAdjustment > 0 {
- packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
- PeersId: c.remoteId,
- AdditionalBytes: windowAdjustment,
- })
- err = c.writePacket(packet)
- }
-
- return
-}
-
-func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
- c.cond.L.Lock()
- defer c.cond.L.Unlock()
-
- if c.err != nil {
- return 0, c.err, 0
- }
-
- for {
- if c.theySentEOF || c.theyClosed || c.dead() {
- return 0, io.EOF, 0
- }
-
- if len(c.pendingRequests) > 0 {
- req := c.pendingRequests[0]
- if len(c.pendingRequests) == 1 {
- c.pendingRequests = nil
- } else {
- oldPendingRequests := c.pendingRequests
- c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
- copy(c.pendingRequests, oldPendingRequests[1:])
- }
-
- return 0, req, 0
- }
-
- if c.length > 0 {
- tail := min(uint32(c.head+c.length), len(c.pendingData))
- n = copy(data, c.pendingData[c.head:tail])
- c.head += n
- c.length -= n
- if c.head == len(c.pendingData) {
- c.head = 0
- }
-
- windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow
- if windowAdjustment < uint32(len(c.pendingData)/2) {
- windowAdjustment = 0
- }
- c.myWindow += windowAdjustment
-
- return
- }
-
- c.cond.Wait()
- }
-
- panic("unreachable")
-}
-
-// getWindowSpace takes, at most, max bytes of space from the peer's window. It
-// returns the number of bytes actually reserved.
-func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
- if c.dead() || c.closed() {
- return 0, io.EOF
- }
- return c.remoteWin.reserve(max), nil
-}
-
-func (c *serverChan) dead() bool {
- return atomic.LoadUint32(&c.isDead) > 0
-}
-
-func (c *serverChan) setDead() {
- atomic.StoreUint32(&c.isDead, 1)
-}
-
-func (c *serverChan) Write(data []byte) (n int, err error) {
- const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
- for len(data) > 0 {
- space := min(c.maxPacket-headerLength, len(data))
- if space, err = c.getWindowSpace(space); err != nil {
- return 0, err
- }
- todo := data
- if uint32(len(todo)) > space {
- todo = todo[:space]
- }
-
- packet := make([]byte, headerLength+len(todo))
- packet[0] = msgChannelData
- marshalUint32(packet[1:], c.remoteId)
- marshalUint32(packet[5:], uint32(len(todo)))
- copy(packet[9:], todo)
-
- if err = c.writePacket(packet); err != nil {
- return
- }
-
- n += len(todo)
- data = data[len(todo):]
- }
-
- return
-}
-
-// Close signals the intent to close the channel.
-func (c *serverChan) Close() error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
-
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
-
- if !c.setClosed() {
- return errors.New("ssh: channel already closed")
- }
- return c.sendClose()
-}
-
-func (c *serverChan) AckRequest(ok bool) error {
- c.serverConn.lock.Lock()
- defer c.serverConn.lock.Unlock()
-
- if c.serverConn.err != nil {
- return c.serverConn.err
- }
-
- if !ok {
- ack := channelRequestFailureMsg{
- PeersId: c.remoteId,
- }
- return c.writePacket(marshal(msgChannelFailure, ack))
- }
-
- ack := channelRequestSuccessMsg{
- PeersId: c.remoteId,
- }
- return c.writePacket(marshal(msgChannelSuccess, ack))
-}
-
-func (c *serverChan) ChannelType() string {
- return c.chanType
-}
-
-func (c *serverChan) ExtraData() []byte {
- return c.extraData
-}
-
-// A clientChan represents a single RFC 4254 channel multiplexed
-// over a SSH connection.
-type clientChan struct {
- channel
- stdin *chanWriter
- stdout *chanReader
- stderr *chanReader
- msg chan interface{}
-}
-
-// newClientChan returns a partially constructed *clientChan
-// using the local id provided. To be usable clientChan.remoteId
-// needs to be assigned once known.
-func newClientChan(cc packetConn, id uint32) *clientChan {
- c := &clientChan{
- channel: channel{
- packetConn: cc,
- localId: id,
- remoteWin: window{Cond: newCond()},
- },
- msg: make(chan interface{}, 16),
- }
- c.stdin = &chanWriter{
- channel: &c.channel,
- }
- c.stdout = &chanReader{
- channel: &c.channel,
- buffer: newBuffer(),
- }
- c.stderr = &chanReader{
- channel: &c.channel,
- buffer: newBuffer(),
- }
- return c
-}
-
-// waitForChannelOpenResponse, if successful, fills out
-// the remoteId and records any initial window advertisement.
-func (c *clientChan) waitForChannelOpenResponse() error {
- switch msg := (<-c.msg).(type) {
- case *channelOpenConfirmMsg:
- if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
- return errors.New("ssh: invalid MaxPacketSize from peer")
- }
- // fixup remoteId field
- c.remoteId = msg.MyId
- c.maxPacket = msg.MaxPacketSize
- c.remoteWin.add(msg.MyWindow)
- return nil
- case *channelOpenFailureMsg:
- return errors.New(safeString(msg.Message))
- }
- return errors.New("ssh: unexpected packet")
-}
-
-// Close signals the intent to close the channel.
-func (c *clientChan) Close() error {
- if !c.setClosed() {
- return errors.New("ssh: channel already closed")
- }
- c.stdout.eof()
- c.stderr.eof()
- return c.sendClose()
-}
-
-// A chanWriter represents the stdin of a remote process.
-type chanWriter struct {
- *channel
- // indicates the writer has been closed. eof is owned by the
- // caller of Write/Close.
- eof bool
-}
-
-// Write writes data to the remote process's standard input.
-func (w *chanWriter) Write(data []byte) (written int, err error) {
- const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
- for len(data) > 0 {
- if w.eof || w.closed() {
- err = io.EOF
- return
- }
- // never send more data than maxPacket even if
- // there is sufficient window.
- n := min(w.maxPacket-headerLength, len(data))
- r := w.remoteWin.reserve(n)
- n = r
- remoteId := w.remoteId
- packet := []byte{
- msgChannelData,
- byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
- byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
- }
- if err = w.writePacket(append(packet, data[:n]...)); err != nil {
- break
- }
- data = data[n:]
- written += int(n)
- }
- return
-}
-
func min(a uint32, b int) uint32 {
if a < uint32(b) {
return a
@@ -563,32 +134,475 @@
return uint32(b)
}
-func (w *chanWriter) Close() error {
-	w.eof = true
-	return w.sendEOF()
+// channelDirection records which side opened a channel.
+type channelDirection uint8
+
+const (
+	// channelInbound marks a channel created by the peer.
+	channelInbound channelDirection = iota
+	// channelOutbound marks a channel created locally.
+	channelOutbound
+)
+
+// channel is an implementation of the Channel interface that works
+// with the mux type.
+type channel struct {
+	// R/O after creation (remoteId of an outbound channel is filled in
+	// when the peer's open confirmation arrives).
+	chanType          string
+	extraData         []byte
+	localId, remoteId uint32
+
+	// maxIncomingPayload and maxRemotePayload are the maximum
+	// payload sizes of normal and extended data packets for
+	// receiving and sending, respectively. The wire packet will
+	// be 9 or 13 bytes larger (excluding encryption overhead).
+	maxIncomingPayload uint32
+	maxRemotePayload   uint32
+
+	mux *mux
+
+	// decided is set to true once an accept or reject message has been
+	// sent (for inbound channels) or received (for outbound channels).
+	decided bool
+
+	// direction contains either channelOutbound, for channels created
+	// locally, or channelInbound, for channels created by the peer.
+	direction channelDirection
+
+	// Pending internal channel messages: open confirmation/failure and
+	// other decoded messages such as replies awaited by SendRequest.
+	msg chan interface{}
+
+	// Since requests have no ID, there can be only one request
+	// with WantReply=true outstanding. This lock is held by a
+	// goroutine that has such an outgoing request pending.
+	sentRequestMu sync.Mutex
+
+	// incomingRequests delivers channel requests received from the peer.
+	incomingRequests chan *Request
+
+	// sentEOF is set by CloseWrite; WriteExtended refuses to send data
+	// afterwards. NOTE(review): read and written without a lock.
+	sentEOF bool
+
+	// thread-safe data
+	remoteWin  window
+	pending    *buffer
+	extPending *buffer
+
+	// windowMu protects myWindow, the flow-control window.
+	windowMu sync.Mutex
+	myWindow uint32
+
+	// writeMu serializes calls to mux.conn.writePacket() and
+	// protects sentClose. This mutex must be different from
+	// windowMu, as writePacket can block if there is a key
+	// exchange pending.
+	writeMu   sync.Mutex
+	sentClose bool
 }
-// A chanReader represents stdout or stderr of a remote process.
-type chanReader struct {
-	*channel // the channel backing this reader
-	*buffer
+// writePacket writes packet to the underlying connection while holding
+// c.writeMu. Once a channel close message has gone out, every further
+// write fails with io.EOF; writing a close message records that fact.
+func (c *channel) writePacket(packet []byte) error {
+	c.writeMu.Lock()
+	defer c.writeMu.Unlock()
+
+	if c.sentClose {
+		return io.EOF
+	}
+	c.sentClose = packet[0] == msgChannelClose
+	return c.mux.conn.writePacket(packet)
 }
-// Read reads data from the remote process's stdout or stderr.
-func (r *chanReader) Read(buf []byte) (int, error) {
-	n, err := r.buffer.Read(buf)
-	if err != nil {
-		if err == io.EOF {
+// sendMessage marshals msg, stamps this channel's remote ID into bytes
+// 1..4 of the encoding, and writes the result with writePacket.
+func (c *channel) sendMessage(msg interface{}) error {
+	if debugMux {
+		log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
+	}
+
+	packet := Marshal(msg)
+	// Assumes every message sent here has the recipient channel ID as its
+	// first field, right after the type byte; it is overwritten
+	// unconditionally.
+	binary.BigEndian.PutUint32(packet[1:], c.remoteId)
+	return c.writePacket(packet)
+}
+
+// WriteExtended writes data to a specific extended stream. These streams are
+// used, for example, for stderr. An extendedCode of 0 sends ordinary channel
+// data instead. Data is split into packets of at most c.maxRemotePayload
+// bytes, each sent only after space has been reserved in the remote window.
+// It returns the number of bytes sent and the first error encountered.
+func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
+	// No data may follow our EOF (see CloseWrite).
+	// NOTE(review): sentEOF is read here without any lock.
+	if c.sentEOF {
+		return 0, io.EOF
+	}
+	// 1 byte message type, 4 bytes remoteId, 4 bytes data length; extended
+	// data adds a 4-byte data type code after remoteId.
+	opCode := byte(msgChannelData)
+	headerLength := uint32(9)
+	if extendedCode > 0 {
+		headerLength += 4
+		opCode = msgChannelExtendedData
+	}
+
+	for len(data) > 0 {
+		space := min(c.maxRemotePayload, len(data))
+		// reserve returns the number of bytes actually granted (the slice
+		// below uses it, so it may be fewer than requested — TODO confirm
+		// in window). It fails once the window is closed; channel.close
+		// closes it to unblock writers.
+		if space, err = c.remoteWin.reserve(space); err != nil {
 			return n, err
 		}
-		return 0, err
+		todo := data[:space]
+
+		packet := make([]byte, headerLength+uint32(len(todo)))
+		packet[0] = opCode
+		binary.BigEndian.PutUint32(packet[1:], c.remoteId)
+		if extendedCode > 0 {
+			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
+		}
+		// The data length always sits in the last 4 header bytes.
+		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
+		copy(packet[headerLength:], todo)
+		if err = c.writePacket(packet); err != nil {
+			return n, err
+		}
+
+		n += len(todo)
+		data = data[len(todo):]
 	}
-	err = r.sendWindowAdj(n)
-	if err == io.EOF && n > 0 {
-		// sendWindowAdjust can return io.EOF if the remote peer has
-		// closed the connection, however we want to defer forwarding io.EOF to the
-		// caller of Read until the buffer has been drained.
-		err = nil
-	}
+
 	return n, err
 }
+
+// handleData processes an incoming msgChannelData or msgChannelExtendedData
+// packet. It validates the header and declared length, debits our
+// flow-control window, and appends the payload to the matching read
+// buffer. Extended data with a type code other than 1 (stderr) is
+// accepted but dropped.
+func (c *channel) handleData(packet []byte) error {
+	// type byte + recipient channel + length; extended data adds a 4-byte
+	// data type code after the recipient channel.
+	headerLen := 9
+	if packet[0] == msgChannelExtendedData {
+		headerLen += 4
+	}
+	if len(packet) < headerLen {
+		// malformed data packet
+		return parseError(packet[0])
+	}
+
+	var extended uint32
+	if headerLen == 13 {
+		extended = binary.BigEndian.Uint32(packet[5:9])
+	}
+
+	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
+	switch {
+	case length == 0:
+		return nil
+	case length > c.maxIncomingPayload:
+		// TODO(hanwen): should send Disconnect?
+		return errors.New("ssh: incoming packet exceeds maximum payload size")
+	}
+
+	data := packet[headerLen:]
+	if uint32(len(data)) != length {
+		return errors.New("ssh: wrong packet length")
+	}
+
+	c.windowMu.Lock()
+	fits := length <= c.myWindow
+	if fits {
+		c.myWindow -= length
+	}
+	c.windowMu.Unlock()
+	if !fits {
+		// TODO(hanwen): should send Disconnect with reason?
+		return errors.New("ssh: remote side wrote too much")
+	}
+
+	switch extended {
+	case 0:
+		c.pending.write(data)
+	case 1:
+		c.extPending.write(data)
+	default:
+		// Other extended streams are not supported; discard the data.
+	}
+	return nil
+}
+
+// adjustWindow gives n bytes back to our receive window and tells the
+// peer it may send that many more bytes.
+func (c *channel) adjustWindow(n uint32) error {
+	c.windowMu.Lock()
+	// Since myWindow is managed on our side, and can never exceed
+	// the initial window setting, we don't worry about overflow.
+	c.myWindow += n
+	c.windowMu.Unlock()
+
+	// PeersId is filled in by sendMessage.
+	return c.sendMessage(windowAdjustMsg{AdditionalBytes: n})
+}
+
+// ReadExtended reads from the normal data stream (extended == 0) or the
+// stderr stream (extended == 1); any other code is rejected. Bytes read
+// are credited back to the peer's send window via adjustWindow.
+func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
+	var src *buffer
+	switch extended {
+	case 0:
+		src = c.pending
+	case 1:
+		src = c.extPending
+	default:
+		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
+	}
+
+	n, err = src.Read(data)
+	if n == 0 {
+		return n, err
+	}
+
+	// The window adjustment can fail with io.EOF once the peer has closed
+	// the connection; data already read is still returned, and io.EOF is
+	// reported by a later Read once the buffer is drained.
+	if err = c.adjustWindow(uint32(n)); err == io.EOF {
+		err = nil
+	}
+	return n, err
+}
+
+// close tears the channel down locally. It marks both read buffers as
+// EOF, closes the msg and incomingRequests Go channels, marks the channel
+// as closed for writing, and releases writers waiting on the remote
+// window. It must be called at most once, because closing an already
+// closed Go channel panics.
+func (c *channel) close() {
+	c.pending.eof()
+	c.extPending.eof()
+	close(c.msg)
+	close(c.incomingRequests)
+	c.writeMu.Lock()
+	// This is not necessary for a normal channel teardown, but if
+	// there was another error, it is.
+	c.sentClose = true
+	c.writeMu.Unlock()
+	// Unblock writers.
+	c.remoteWin.close()
+}
+
+// responseMessageReceived validates an open confirmation or failure from
+// the peer. Such a response is only legal once, and only on a channel we
+// opened ourselves; on success the channel is marked as decided.
+func (c *channel) responseMessageReceived() error {
+	switch {
+	case c.direction == channelInbound:
+		return errors.New("ssh: channel response message received on inbound channel")
+	case c.decided:
+		return errors.New("ssh: duplicate response received for channel")
+	}
+	c.decided = true
+	return nil
+}
+
+// handlePacket dispatches one packet addressed to this channel. Data, close
+// and EOF messages are recognised from the type byte alone; all other
+// messages are decoded first. It returns an error if the packet cannot be
+// decoded or is not acceptable in this channel's current state.
+func (c *channel) handlePacket(packet []byte) error {
+	switch packet[0] {
+	case msgChannelData, msgChannelExtendedData:
+		return c.handleData(packet)
+	case msgChannelClose:
+		// Answer the peer's close with our own. The error is ignored on
+		// purpose: if we already sent a close, writePacket returns io.EOF.
+		c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
+		c.mux.chanList.remove(c.localId)
+		c.close()
+		return nil
+	case msgChannelEOF:
+		// RFC 4254 is mute on how EOF affects dataExt messages but
+		// it is logical to signal EOF at the same time.
+		c.extPending.eof()
+		c.pending.eof()
+		return nil
+	}
+
+	decoded, err := decode(packet)
+	if err != nil {
+		return err
+	}
+
+	switch msg := decoded.(type) {
+	case *channelOpenFailureMsg:
+		if err := c.responseMessageReceived(); err != nil {
+			return err
+		}
+		// The open was refused, so release the channel slot. PeersId names
+		// the recipient channel, presumably our local ID.
+		c.mux.chanList.remove(msg.PeersId)
+		c.msg <- msg
+	case *channelOpenConfirmMsg:
+		if err := c.responseMessageReceived(); err != nil {
+			return err
+		}
+		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+			return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
+		}
+		c.remoteId = msg.MyId
+		c.maxRemotePayload = msg.MaxPacketSize
+		c.remoteWin.add(msg.MyWindow)
+		c.msg <- msg
+	case *windowAdjustMsg:
+		if !c.remoteWin.add(msg.AdditionalBytes) {
+			return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
+		}
+	case *channelRequestMsg:
+		req := Request{
+			Type:      msg.Request,
+			WantReply: msg.WantReply,
+			Payload:   msg.RequestSpecificData,
+			ch:        c,
+		}
+
+		// Blocks once the buffered request queue is full, which is why the
+		// request channel returned by Accept must be serviced.
+		c.incomingRequests <- &req
+	default:
+		// Anything else (e.g. replies to our own channel requests) goes to
+		// whoever is waiting on c.msg, such as SendRequest.
+		c.msg <- msg
+	}
+	return nil
+}
+
+// newChannel allocates a channel of the given type and direction,
+// registers it in m.chanList to obtain its local ID, and returns it. The
+// receive window starts at channelWindowSize.
+func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
+	ch := &channel{
+		chanType:  chanType,
+		extraData: extraData,
+		direction: direction,
+		mux:       m,
+
+		myWindow:   channelWindowSize,
+		remoteWin:  window{Cond: newCond()},
+		pending:    newBuffer(),
+		extPending: newBuffer(),
+
+		msg:              make(chan interface{}, 16),
+		incomingRequests: make(chan *Request, 16),
+	}
+	ch.localId = m.chanList.add(ch)
+	return ch
+}
+
+var (
+	// errUndecided is returned by channel operations used before the
+	// channel has been accepted or rejected.
+	errUndecided = errors.New("ssh: must Accept or Reject channel")
+	// errDecidedAlready is returned by a second call to Accept or Reject.
+	errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
+)
+
+// extChannel exposes one extended data stream of a channel, selected by
+// its data type code, as an io.ReadWriter.
+type extChannel struct {
+	code uint32
+	ch   *channel
+}
+
+// Write sends data on the extended stream identified by e.code.
+func (e *extChannel) Write(data []byte) (n int, err error) {
+	ch, code := e.ch, e.code
+	return ch.WriteExtended(data, code)
+}
+
+// Read receives data from the extended stream identified by e.code.
+func (e *extChannel) Read(data []byte) (n int, err error) {
+	ch, code := e.ch, e.code
+	return ch.ReadExtended(data, code)
+}
+
+// Accept confirms an inbound channel open request. It advertises our
+// receive window and channelMaxPacket as the maximum packet size, then
+// returns the channel with its stream of incoming requests. It fails with
+// errDecidedAlready if Accept or Reject was called before.
+func (ch *channel) Accept() (Channel, <-chan *Request, error) {
+	if ch.decided {
+		return nil, nil, errDecidedAlready
+	}
+
+	ch.maxIncomingPayload = channelMaxPacket
+	reply := channelOpenConfirmMsg{
+		PeersId:       ch.remoteId,
+		MyId:          ch.localId,
+		MyWindow:      ch.myWindow,
+		MaxPacketSize: ch.maxIncomingPayload,
+	}
+	ch.decided = true
+
+	if err := ch.sendMessage(reply); err != nil {
+		return nil, nil, err
+	}
+	return ch, ch.incomingRequests, nil
+}
+
+// Reject refuses an inbound channel open request, sending reason and
+// message to the peer. It fails with errDecidedAlready if Accept or
+// Reject was called before.
+func (ch *channel) Reject(reason RejectionReason, message string) error {
+	if ch.decided {
+		return errDecidedAlready
+	}
+
+	failure := channelOpenFailureMsg{
+		PeersId:  ch.remoteId,
+		Reason:   reason,
+		Message:  message,
+		Language: "en",
+	}
+	ch.decided = true
+	return ch.sendMessage(failure)
+}
+
+// Read reads from the channel's normal (non-extended) data stream. It
+// fails with errUndecided until the channel has been decided.
+func (ch *channel) Read(data []byte) (int, error) {
+	if ch.decided {
+		return ch.ReadExtended(data, 0)
+	}
+	return 0, errUndecided
+}
+
+// Write writes to the channel's normal (non-extended) data stream. It
+// fails with errUndecided until the channel has been decided.
+func (ch *channel) Write(data []byte) (int, error) {
+	if ch.decided {
+		return ch.WriteExtended(data, 0)
+	}
+	return 0, errUndecided
+}
+
+// CloseWrite sends EOF to the peer. Afterwards WriteExtended refuses to
+// send more data, but requests may still be sent.
+func (ch *channel) CloseWrite() error {
+	if !ch.decided {
+		return errUndecided
+	}
+	ch.sentEOF = true
+	eof := channelEOFMsg{PeersId: ch.remoteId}
+	return ch.sendMessage(eof)
+}
+
+// Close sends a channel close message to the peer. Once it has been
+// written, writePacket rejects every further message with io.EOF.
+func (ch *channel) Close() error {
+	if !ch.decided {
+		return errUndecided
+	}
+	closeMsg := channelCloseMsg{PeersId: ch.remoteId}
+	return ch.sendMessage(closeMsg)
+}
+
+// Extended returns an io.ReadWriter for the SSH extended data stream with
+// the given type code (1 is stderr). It returns nil while the channel is
+// still undecided.
+func (ch *channel) Extended(code uint32) io.ReadWriter {
+	if ch.decided {
+		return &extChannel{code: code, ch: ch}
+	}
+	return nil
+}
+
+// Stderr returns the extended stream with type code 1 (stderr).
+func (ch *channel) Stderr() io.ReadWriter {
+	const stderrCode = 1
+	return ch.Extended(stderrCode)
+}
+
+// SendRequest sends a channel request named name with the given payload.
+// Without wantReply it returns false after sending. With wantReply it
+// waits for the peer's success or failure message and reports it as the
+// boolean result; it returns io.EOF if the channel is torn down first.
+func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+	if !ch.decided {
+		return false, errUndecided
+	}
+
+	if !wantReply {
+		msg := channelRequestMsg{
+			PeersId:             ch.remoteId,
+			Request:             name,
+			WantReply:           false,
+			RequestSpecificData: payload,
+		}
+		return false, ch.sendMessage(msg)
+	}
+
+	// Replies carry no request ID, so only one reply may be awaited at a
+	// time; keep the lock until ours has been consumed.
+	ch.sentRequestMu.Lock()
+	defer ch.sentRequestMu.Unlock()
+
+	msg := channelRequestMsg{
+		PeersId:             ch.remoteId,
+		Request:             name,
+		WantReply:           true,
+		RequestSpecificData: payload,
+	}
+	if err := ch.sendMessage(msg); err != nil {
+		return false, err
+	}
+
+	reply, open := <-ch.msg
+	if !open {
+		return false, io.EOF
+	}
+	switch reply.(type) {
+	case *channelRequestSuccessMsg:
+		return true, nil
+	case *channelRequestFailureMsg:
+		return false, nil
+	}
+	return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", reply)
+}
+
+// ackRequest answers a channel request from the peer with a success
+// (ok == true) or failure message. It fails with errUndecided until the
+// channel has been decided.
+func (ch *channel) ackRequest(ok bool) error {
+	if !ch.decided {
+		return errUndecided
+	}
+
+	if ok {
+		return ch.sendMessage(channelRequestSuccessMsg{PeersId: ch.remoteId})
+	}
+	return ch.sendMessage(channelRequestFailureMsg{PeersId: ch.remoteId})
+}
+
+// ChannelType reports the channel type given when the channel was created.
+func (ch *channel) ChannelType() string {
+	t := ch.chanType
+	return t
+}
+
+// ExtraData reports the type-specific payload given when the channel was
+// created.
+func (ch *channel) ExtraData() []byte {
+	extra := ch.extraData
+	return extra
+}
diff --git a/ssh/cipher.go b/ssh/cipher.go
index bc2e983..a58f10b 100644
--- a/ssh/cipher.go
+++ b/ssh/cipher.go
@@ -8,11 +8,28 @@
"crypto/aes"
"crypto/cipher"
"crypto/rc4"
+ "crypto/subtle"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
)
-// streamDump is used to dump the initial keystream for stream ciphers. It is a
-// a write-only buffer, and not intended for reading so do not require a mutex.
-var streamDump [512]byte
+const (
+	packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
+
+	// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
+	// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
+	// indicates implementations SHOULD be able to handle larger packet sizes, but it
+	// leaves the upper limit to the implementation.
+	//
+	// OpenSSH caps their maxPacket at 256kB so we choose to do
+	// the same. maxPacket is also used to ensure that uint32
+	// length fields do not overflow, so it should remain well
+	// below 4G.
+	maxPacket = 256 * 1024
+)
// noneCipher implements cipher.Stream and provides no encryption. It is used
// by the transport before the first key-exchange.
@@ -34,14 +51,14 @@
return rc4.NewCipher(key)
}
-type cipherMode struct {
+// streamCipherMode describes how to build one negotiable cipher. keySize
+// and ivSize are the key and IV lengths in bytes (createStream panics if
+// the key is shorter than keySize). skip is the number of initial
+// keystream bytes to discard. createFunc constructs the stream; it is nil
+// for AES-GCM, which is not a stream cipher and is constructed by a
+// special case instead.
+type streamCipherMode struct {
 	keySize    int
 	ivSize     int
 	skip       int
 	createFunc func(key, iv []byte) (cipher.Stream, error)
 }
-func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
+func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
@@ -54,6 +71,11 @@
return nil, err
}
+ var streamDump []byte
+ if c.skip > 0 {
+ streamDump = make([]byte, 512)
+ }
+
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
@@ -66,18 +88,10 @@
return stream, nil
}
-// Specifies a default set of ciphers and a preference order. This is based on
-// OpenSSH's default client preference order, minus algorithms that are not
-// implemented.
-var DefaultCipherOrder = []string{
- "aes128-ctr", "aes192-ctr", "aes256-ctr",
- "arcfour256", "arcfour128",
-}
-
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
-var cipherModes = map[string]*cipherMode{
+var cipherModes = map[string]*streamCipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
@@ -88,13 +102,237 @@
// They are defined in the order specified in the RFC.
"arcfour128": {16, 0, 1536, newRC4},
"arcfour256": {32, 0, 1536, newRC4},
+
+ // AES-GCM is not a stream cipher, so it is constructed as a special
+ // case (16-byte key, 12-byte IV, nil createFunc). If we add any more
+ // non-stream ciphers, we should find a cleaner way to do this.
+ gcmCipherID: {16, 12, 0, nil},
}
-// defaultKeyExchangeOrder specifies a default set of key exchange algorithms
-// with preferences.
-var defaultKeyExchangeOrder = []string{
- // P384 and P521 are not constant-time yet, but since we don't
- // reuse ephemeral keys, using them for ECDH should be OK.
- kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
- kexAlgoDH14SHA1, kexAlgoDH1SHA1,
+// prefixLen is the length of the packet prefix that contains the packet length
+// (4 bytes, big-endian) and the number of padding bytes (1 byte).
+const prefixLen = 5
+
+// streamPacketCipher is a packetCipher using a stream cipher.
+type streamPacketCipher struct {
+ mac hash.Hash // nil means no MAC is computed or checked
+ cipher cipher.Stream
+
+ // The following members are to avoid per-packet allocations.
+ prefix [prefixLen]byte
+ seqNumBytes [4]byte
+ padding [2 * packetSizeMultiple]byte
+ packetData []byte
+ macResult []byte
+}
+
+// readPacket reads and decrypts a single packet from r. The returned payload aliases s.packetData, so it is only valid until the next call.
+func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
+ if _, err := io.ReadFull(r, s.prefix[:]); err != nil {
+ return nil, err
+ }
+
+ s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
+ length := binary.BigEndian.Uint32(s.prefix[0:4])
+ paddingLength := uint32(s.prefix[4])
+
+ var macSize uint32
+ if s.mac != nil {
+ s.mac.Reset()
+ binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
+ s.mac.Write(s.seqNumBytes[:])
+ s.mac.Write(s.prefix[:])
+ macSize = uint32(s.mac.Size())
+ }
+
+ if length <= paddingLength+1 {
+ return nil, errors.New("ssh: invalid packet length, packet too small")
+ }
+
+ if length > maxPacket {
+ return nil, errors.New("ssh: invalid packet length, packet too large")
+ }
+
+ // the maxPacket check above ensures that length-1+macSize
+ // does not overflow.
+ if uint32(cap(s.packetData)) < length-1+macSize {
+ s.packetData = make([]byte, length-1+macSize)
+ } else {
+ s.packetData = s.packetData[:length-1+macSize]
+ }
+
+ if _, err := io.ReadFull(r, s.packetData); err != nil {
+ return nil, err
+ }
+ mac := s.packetData[length-1:] // the MAC trails the data and is not passed through the cipher
+ data := s.packetData[:length-1]
+ s.cipher.XORKeyStream(data, data)
+
+ if s.mac != nil {
+ s.mac.Write(data)
+ s.macResult = s.mac.Sum(s.macResult[:0])
+ if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
+ return nil, errors.New("ssh: MAC failure")
+ }
+ }
+
+ return s.packetData[:length-paddingLength-1], nil // drop the random padding and the MAC
+}
+
+// writePacket encrypts packet in place (modifying the caller's slice) and writes it to w with its length prefix, random padding and MAC.
+func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
+ if len(packet) > maxPacket {
+ return errors.New("ssh: packet too large")
+ }
+
+ paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple // total becomes a multiple of packetSizeMultiple
+ if paddingLength < 4 {
+ paddingLength += packetSizeMultiple
+ }
+
+ length := len(packet) + 1 + paddingLength
+ binary.BigEndian.PutUint32(s.prefix[:], uint32(length))
+ s.prefix[4] = byte(paddingLength)
+ padding := s.padding[:paddingLength]
+ if _, err := io.ReadFull(rand, padding); err != nil {
+ return err
+ }
+
+ if s.mac != nil { // the MAC is computed over the plaintext, before encryption below
+ s.mac.Reset()
+ binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
+ s.mac.Write(s.seqNumBytes[:])
+ s.mac.Write(s.prefix[:])
+ s.mac.Write(packet)
+ s.mac.Write(padding)
+ }
+
+ s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
+ s.cipher.XORKeyStream(packet, packet)
+ s.cipher.XORKeyStream(padding, padding)
+
+ if _, err := w.Write(s.prefix[:]); err != nil {
+ return err
+ }
+ if _, err := w.Write(packet); err != nil {
+ return err
+ }
+ if _, err := w.Write(padding); err != nil {
+ return err
+ }
+
+ if s.mac != nil {
+ s.macResult = s.mac.Sum(s.macResult[:0])
+ if _, err := w.Write(s.macResult); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+type gcmCipher struct { // packetCipher backed by AES-GCM
+ aead cipher.AEAD
+ prefix [4]byte // packet_length: sent unencrypted, authenticated as additional data
+ iv []byte // nonce; incIV advances bytes 4..11 after each packet
+ buf []byte // scratch buffer reused across packets
+}
+
+func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { // macKey is unused: GCM authenticates the data itself
+ c, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+
+ aead, err := cipher.NewGCM(c)
+ if err != nil {
+ return nil, err
+ }
+
+ return &gcmCipher{
+ aead: aead,
+ iv: iv,
+ }, nil
+}
+
+const gcmTagSize = 16 // length in bytes of the tag appended to each encrypted packet
+
+func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
+ // Pad out to multiple of 16 bytes. This is different from the
+ // stream cipher because that encrypts the length too.
+ padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple)
+ if padding < 4 {
+ padding += packetSizeMultiple
+ }
+
+ length := uint32(len(packet) + int(padding) + 1) // NOTE(review): no maxPacket check here, unlike streamPacketCipher.writePacket
+ binary.BigEndian.PutUint32(c.prefix[:], length)
+ if _, err := w.Write(c.prefix[:]); err != nil {
+ return err
+ }
+
+ if cap(c.buf) < int(length) {
+ c.buf = make([]byte, length)
+ } else {
+ c.buf = c.buf[:length]
+ }
+
+ c.buf[0] = padding
+ copy(c.buf[1:], packet)
+ if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil {
+ return err
+ }
+ c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) // encrypt in place, appending the tag
+ if _, err := w.Write(c.buf); err != nil {
+ return err
+ }
+ c.incIV()
+
+ return nil
+}
+
+func (c *gcmCipher) incIV() { // big-endian increment of iv[4:12]; iv[0:4] stays fixed
+ for i := 4 + 7; i >= 4; i-- {
+ c.iv[i]++
+ if c.iv[i] != 0 {
+ break
+ }
+ }
+}
+
+// readPacket reads, authenticates and decrypts one packet. The returned payload aliases c.buf.
+func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
+ if _, err := io.ReadFull(r, c.prefix[:]); err != nil {
+ return nil, err
+ }
+ length := binary.BigEndian.Uint32(c.prefix[:]) // sent in the clear, authenticated as additional data
+ if length > maxPacket {
+ return nil, errors.New("ssh: max packet length exceeded")
+ }
+ if cap(c.buf) < int(length+gcmTagSize) {
+ c.buf = make([]byte, length+gcmTagSize)
+ } else {
+ c.buf = c.buf[:length+gcmTagSize]
+ }
+ if _, err := io.ReadFull(r, c.buf); err != nil {
+ return nil, err
+ }
+ plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:])
+ if err != nil {
+ return nil, err
+ }
+ c.incIV()
+
+ // An empty payload has no padding_length byte. RFC 4253 allows 4..255 padding bytes, so only the minimum is enforced.
+ if len(plain) == 0 {
+ return nil, errors.New("ssh: empty packet")
+ }
+ padding := plain[0]
+ if padding < 4 {
+ return nil, fmt.Errorf("ssh: illegal padding %d", padding)
+ }
+ if int(padding)+1 >= len(plain) { // convert first: byte(255)+1 would wrap to 0
+ return nil, fmt.Errorf("ssh: padding %d too large", padding)
+ }
+ return plain[1 : length-uint32(padding)], nil
}
diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go
index ea27bd8..e279af0 100644
--- a/ssh/cipher_test.go
+++ b/ssh/cipher_test.go
@@ -6,57 +6,54 @@
import (
"bytes"
+ "crypto"
+ "crypto/rand"
"testing"
)
-// TestCipherReversal tests that each cipher factory produces ciphers that can
-// encrypt and decrypt some data successfully.
-func TestCipherReversal(t *testing.T) {
- testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
- testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
- testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
-
- cryptBuffer := make([]byte, 32)
-
- for name, cipherMode := range cipherModes {
- encrypter, err := cipherMode.createCipher(testKey, testIv)
- if err != nil {
- t.Errorf("failed to create encrypter for %q: %s", name, err)
- continue
- }
- decrypter, err := cipherMode.createCipher(testKey, testIv)
- if err != nil {
- t.Errorf("failed to create decrypter for %q: %s", name, err)
- continue
- }
-
- copy(cryptBuffer, testData)
-
- encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
- if name == "none" {
- if !bytes.Equal(cryptBuffer, testData) {
- t.Errorf("encryption made change with 'none' cipher")
- continue
- }
- } else {
- if bytes.Equal(cryptBuffer, testData) {
- t.Errorf("encryption made no change with %q", name)
- continue
- }
- }
-
- decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
- if !bytes.Equal(cryptBuffer, testData) {
- t.Errorf("decrypted bytes not equal to input with %q", name)
- continue
+func TestDefaultCiphersExist(t *testing.T) {
+ for _, cipherAlgo := range supportedCiphers {
+ if _, ok := cipherModes[cipherAlgo]; !ok {
+ t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}
-func TestDefaultCiphersExist(t *testing.T) {
- for _, cipherAlgo := range DefaultCipherOrder {
- if _, ok := cipherModes[cipherAlgo]; !ok {
- t.Errorf("default cipher %q is unknown", cipherAlgo)
+func TestPacketCiphers(t *testing.T) { // round-trips one packet through every cipher in cipherModes
+ for cipher := range cipherModes {
+ kr := &kexResult{Hash: crypto.SHA1}
+ algs := directionAlgorithms{
+ Cipher: cipher,
+ MAC: "hmac-sha1",
+ Compression: "none",
+ }
+ client, err := newPacketCipher(clientKeys, algs, kr)
+ if err != nil {
+ t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
+ continue
+ }
+ server, err := newPacketCipher(clientKeys, algs, kr) // same keys: the server decrypts the client->server direction
+ if err != nil {
+ t.Errorf("newPacketCipher(server, %q): %v", cipher, err)
+ continue
+ }
+
+ want := "bla bla"
+ input := []byte(want) // writePacket may encrypt input in place, so compare against want
+ buf := &bytes.Buffer{}
+ if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
+ t.Errorf("writePacket(%q): %v", cipher, err)
+ continue
+ }
+
+ packet, err := server.readPacket(0, buf)
+ if err != nil {
+ t.Errorf("readPacket(%q): %v", cipher, err)
+ continue
+ }
+
+ if string(packet) != want {
+ t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want)
}
}
}
diff --git a/ssh/client.go b/ssh/client.go
index e2d2557..a8d5235 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -5,403 +5,158 @@
package ssh
import (
- "crypto/rand"
- "encoding/binary"
"errors"
"fmt"
- "io"
"net"
"sync"
)
-// ClientConn represents the client side of an SSH connection.
-type ClientConn struct {
- transport *transport
- config *ClientConfig
- chanList // channels associated with this connection
- forwardList // forwarded tcpip connections from the remote side
- globalRequest
+// Client implements a traditional SSH client that supports shells,
+// subprocesses, port forwarding and tunneled dialing.
+type Client struct {
+ Conn
- // Address as passed to the Dial function.
- dialAddress string
-
- serverVersion string
+ forwards forwardList // forwarded tcpip connections from the remote side
+ mu sync.Mutex
+ channelHandlers map[string]chan NewChannel
}
-type globalRequest struct {
- sync.Mutex
- response chan interface{}
+// HandleChannelOpen returns a channel on which NewChannel requests
+// for the given type are sent. If the type is already being handled,
+// nil is returned. The channel is closed when the connection is closed.
+func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.channelHandlers == nil {
+ // The connection is closed; hand back an already-closed channel.
+ closed := make(chan NewChannel)
+ close(closed)
+ return closed
+ }
+
+ if c.channelHandlers[channelType] != nil {
+ // Another caller already owns this channel type.
+ return nil
+ }
+
+ ch := make(chan NewChannel, 16)
+ c.channelHandlers[channelType] = ch
+ return ch
}
-// Client returns a new SSH client connection using c as the underlying transport.
-func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
- return clientWithAddress(c, "", config)
+// NewClient creates a Client on top of the given connection. It starts goroutines that consume chans and reqs.
+func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
+ conn := &Client{
+ Conn: c,
+ channelHandlers: make(map[string]chan NewChannel, 1),
+ }
+
+ go conn.handleGlobalRequests(reqs)
+ go conn.handleChannelOpens(chans)
+ go func() { // after Wait returns (presumably when the connection ends), close all port forwards
+ conn.Wait()
+ conn.forwards.closeAll()
+ }()
+ go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) // remote port-forward connections
+ return conn
}
-func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
- conn := &ClientConn{
- transport: newTransport(c, config.rand(), true /* is client */),
- config: config,
- globalRequest: globalRequest{response: make(chan interface{}, 1)},
- dialAddress: addr,
+// NewClientConn establishes an authenticated SSH connection using c
+// as the underlying transport. The Request and NewChannel channels
+// must be serviced or the connection will hang.
+func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
+ fullConf := *config
+ fullConf.SetDefaults()
+ conn := &connection{
+ sshConn: sshConn{conn: c},
}
- if err := conn.handshake(); err != nil {
- conn.transport.Close()
- return nil, fmt.Errorf("handshake failed: %v", err)
+ if err := conn.clientHandshake(addr, &fullConf); err != nil {
+ c.Close()
+ return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
}
- go conn.mainLoop()
- return conn, nil
+ conn.mux = newMux(conn.transport)
+ return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
}
-// Close closes the connection.
-func (c *ClientConn) Close() error { return c.transport.Close() }
-
-// LocalAddr returns the local network address.
-func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
-
-// RemoteAddr returns the remote network address.
-func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
-
-// handshake performs the client side key exchange. See RFC 4253 Section 7.
-func (c *ClientConn) handshake() error {
- clientVersion := []byte(packageVersion)
- if c.config.ClientVersion != "" {
- clientVersion = []byte(c.config.ClientVersion)
+// clientHandshake performs the client side key exchange. See RFC 4253 Section
+// 7.
+func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
+ c.clientVersion = []byte(packageVersion)
+ if config.ClientVersion != "" {
+ c.clientVersion = []byte(config.ClientVersion)
}
- serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion)
- if err != nil {
- return err
- }
- c.serverVersion = string(serverVersion)
-
- clientKexInit := kexInitMsg{
- KexAlgos: c.config.Crypto.kexes(),
- ServerHostKeyAlgos: supportedHostKeyAlgos,
- CiphersClientServer: c.config.Crypto.ciphers(),
- CiphersServerClient: c.config.Crypto.ciphers(),
- MACsClientServer: c.config.Crypto.macs(),
- MACsServerClient: c.config.Crypto.macs(),
- CompressionClientServer: supportedCompressions,
- CompressionServerClient: supportedCompressions,
- }
- kexInitPacket := marshal(msgKexInit, clientKexInit)
- if err := c.transport.writePacket(kexInitPacket); err != nil {
- return err
- }
- packet, err := c.transport.readPacket()
+ var err error
+ c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
if err != nil {
return err
}
- var serverKexInit kexInitMsg
- if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
+ c.transport = newClientTransport(
+ newTransport(c.sshConn.conn, config.Rand, true /* is client */),
+ c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
+ if err := c.transport.requestKeyChange(); err != nil {
return err
}
- algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit)
- if algs == nil {
- return errors.New("ssh: no common algorithms")
- }
-
- if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
- // The server sent a Kex message for the wrong algorithm,
- // which we have to ignore.
- if _, err := c.transport.readPacket(); err != nil {
- return err
- }
- }
-
- kex, ok := kexAlgoMap[algs.kex]
- if !ok {
- return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
- }
-
- magics := handshakeMagics{
- clientVersion: clientVersion,
- serverVersion: serverVersion,
- clientKexInit: kexInitPacket,
- serverKexInit: packet,
- }
- result, err := kex.Client(c.transport, c.config.rand(), &magics)
- if err != nil {
+ if packet, err := c.transport.readPacket(); err != nil {
return err
+ } else if packet[0] != msgNewKeys {
+ return unexpectedMessageError(msgNewKeys, packet[0])
}
-
- err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, result.Signature)
- if err != nil {
- return err
- }
-
- if checker := c.config.HostKeyChecker; checker != nil {
- err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey)
- if err != nil {
- return err
- }
- }
-
- c.transport.prepareKeyChange(algs, result)
-
- if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil {
- return err
- }
- if packet, err = c.transport.readPacket(); err != nil {
- return err
- }
- if packet[0] != msgNewKeys {
- return UnexpectedMessageError{msgNewKeys, packet[0]}
- }
- return c.authenticate()
+ return c.clientAuthenticate(config)
}
-// Verify the host key obtained in the key exchange.
-func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error {
- hostKey, rest, ok := ParsePublicKey(hostKeyBytes)
- if len(rest) > 0 || !ok {
- return errors.New("ssh: could not parse hostkey")
- }
-
- sig, rest, ok := parseSignatureBody(signature)
+// verifyHostKeySignature verifies the host key obtained in the key
+// exchange.
+func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error {
+ sig, rest, ok := parseSignatureBody(result.Signature)
if len(rest) > 0 || !ok {
return errors.New("ssh: signature parse error")
}
- if sig.Format != hostKeyAlgo {
- return fmt.Errorf("ssh: unexpected signature type %q", sig.Format)
- }
- if !hostKey.Verify(data, sig.Blob) {
- return errors.New("ssh: host key signature error")
- }
- return nil
+ return hostKey.Verify(result.H, sig)
}
-// mainLoop reads incoming messages and routes channel messages
-// to their respective ClientChans.
-func (c *ClientConn) mainLoop() {
- defer func() {
- c.transport.Close()
- c.chanList.closeAll()
- c.forwardList.closeAll()
- }()
-
- for {
- packet, err := c.transport.readPacket()
- if err != nil {
- break
- }
- // TODO(dfc) A note on blocking channel use.
- // The msg, data and dataExt channels of a clientChan can
- // cause this loop to block indefinitely if the consumer does
- // not service them.
- switch packet[0] {
- case msgChannelData:
- if len(packet) < 9 {
- // malformed data packet
- return
- }
- remoteId := binary.BigEndian.Uint32(packet[1:5])
- length := binary.BigEndian.Uint32(packet[5:9])
- packet = packet[9:]
-
- if length != uint32(len(packet)) {
- return
- }
- ch, ok := c.getChan(remoteId)
- if !ok {
- return
- }
- ch.stdout.write(packet)
- case msgChannelExtendedData:
- if len(packet) < 13 {
- // malformed data packet
- return
- }
- remoteId := binary.BigEndian.Uint32(packet[1:5])
- datatype := binary.BigEndian.Uint32(packet[5:9])
- length := binary.BigEndian.Uint32(packet[9:13])
- packet = packet[13:]
-
- if length != uint32(len(packet)) {
- return
- }
- // RFC 4254 5.2 defines data_type_code 1 to be data destined
- // for stderr on interactive sessions. Other data types are
- // silently discarded.
- if datatype == 1 {
- ch, ok := c.getChan(remoteId)
- if !ok {
- return
- }
- ch.stderr.write(packet)
- }
- default:
- decoded, err := decode(packet)
- if err != nil {
- if _, ok := err.(UnexpectedMessageError); ok {
- fmt.Printf("mainLoop: unexpected message: %v\n", err)
- continue
- }
- return
- }
- switch msg := decoded.(type) {
- case *channelOpenMsg:
- c.handleChanOpen(msg)
- case *channelOpenConfirmMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelOpenFailureMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelCloseMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.Close()
- close(ch.msg)
- c.chanList.remove(msg.PeersId)
- case *channelEOFMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.stdout.eof()
- // RFC 4254 is mute on how EOF affects dataExt messages but
- // it is logical to signal EOF at the same time.
- ch.stderr.eof()
- case *channelRequestSuccessMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelRequestFailureMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelRequestMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *windowAdjustMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- if !ch.remoteWin.add(msg.AdditionalBytes) {
- // invalid window update
- return
- }
- case *globalRequestMsg:
- // This handles keepalive messages and matches
- // the behaviour of OpenSSH.
- if msg.WantReply {
- c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
- }
- case *globalRequestSuccessMsg, *globalRequestFailureMsg:
- c.globalRequest.response <- msg
- case *disconnectMsg:
- return
- default:
- fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
- }
- }
- }
-}
-
-// Handle channel open messages from the remote side.
-func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
- if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
- c.sendConnectionFailed(msg.PeersId)
- }
-
- switch msg.ChanType {
- case "forwarded-tcpip":
- laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData)
- if !ok {
- // invalid request
- c.sendConnectionFailed(msg.PeersId)
- return
- }
-
- l, ok := c.forwardList.lookup(*laddr)
- if !ok {
- // TODO: print on a more structured log.
- fmt.Println("could not find forward list entry for", laddr)
- // Section 7.2, implementations MUST reject spurious incoming
- // connections.
- c.sendConnectionFailed(msg.PeersId)
- return
- }
- raddr, rest, ok := parseTCPAddr(rest)
- if !ok {
- // invalid request
- c.sendConnectionFailed(msg.PeersId)
- return
- }
- ch := c.newChan(c.transport)
- ch.remoteId = msg.PeersId
- ch.remoteWin.add(msg.PeersWindow)
- ch.maxPacket = msg.MaxPacketSize
-
- m := channelOpenConfirmMsg{
- PeersId: ch.remoteId,
- MyId: ch.localId,
- MyWindow: channelWindowSize,
- MaxPacketSize: channelMaxPacketSize,
- }
-
- c.transport.writePacket(marshal(msgChannelOpenConfirm, m))
- l <- forward{ch, raddr}
- default:
- // unknown channel type
- m := channelOpenFailureMsg{
- PeersId: msg.PeersId,
- Reason: UnknownChannelType,
- Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType),
- Language: "en_US.UTF-8",
- }
- c.transport.writePacket(marshal(msgChannelOpenFailure, m))
- }
-}
-
-// sendGlobalRequest sends a global request message as specified
-// in RFC4254 section 4. To correctly synchronise messages, a lock
-// is held internally until a response is returned.
-func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
- c.globalRequest.Lock()
- defer c.globalRequest.Unlock()
- if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil {
+// NewSession opens a new Session for this client. (A session is a remote
+// execution of a program.)
+func (c *Client) NewSession() (*Session, error) {
+ ch, in, err := c.OpenChannel("session", nil)
+ if err != nil {
return nil, err
}
- r := <-c.globalRequest.response
- if r, ok := r.(*globalRequestSuccessMsg); ok {
- return r, nil
- }
- return nil, errors.New("request failed")
+ return newSession(ch, in)
}
-// sendConnectionFailed rejects an incoming channel identified
-// by remoteId.
-func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
- m := channelOpenFailureMsg{
- PeersId: remoteId,
- Reason: ConnectionFailed,
- Message: "invalid request",
- Language: "en_US.UTF-8",
+func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
+ for r := range incoming {
+ // This handles keepalive messages and matches
+ // the behaviour of OpenSSH.
+ r.Reply(false, nil)
}
- return c.transport.writePacket(marshal(msgChannelOpenFailure, m))
+}
+
+// handleChannelOpens dispatches channel open requests from the remote side to registered handlers, rejecting unknown types.
+func (c *Client) handleChannelOpens(in <-chan NewChannel) {
+ for ch := range in {
+ c.mu.Lock()
+ handler := c.channelHandlers[ch.ChannelType()]
+ c.mu.Unlock()
+
+ if handler != nil {
+ handler <- ch // blocks while the handler's buffer is full
+ } else {
+ ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType()))
+ }
+ }
+ // in is closed: close every handler channel and nil the map so HandleChannelOpen treats the client as closed.
+ c.mu.Lock()
+ for _, ch := range c.channelHandlers {
+ close(ch)
+ }
+ c.channelHandlers = nil
+ c.mu.Unlock()
}
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
@@ -413,7 +168,7 @@
return nil, b, false
}
port, b, ok := parseUint32(b)
- if !ok {
+ if !ok || port == 0 || port > 65535 {
return nil, b, false
}
ip := net.ParseIP(string(addr))
@@ -423,102 +178,44 @@
return &net.TCPAddr{IP: ip, Port: int(port)}, b, true
}
-// Dial connects to the given network address using net.Dial and
-// then initiates a SSH handshake, returning the resulting client connection.
-func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
+// Dial starts a client connection to the given SSH server. It is a
+// convenience function that connects to the given network address,
+// initiates the SSH handshake, and then sets up a Client. For access
+// to incoming channels and requests, use net.Dial with NewClientConn
+// instead.
+func Dial(network, addr string, config *ClientConfig) (*Client, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
- return clientWithAddress(conn, addr, config)
+ c, chans, reqs, err := NewClientConn(conn, addr, config)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(c, chans, reqs), nil
}
-// A ClientConfig structure is used to configure a ClientConn. After one has
-// been passed to an SSH function it must not be modified.
+// A ClientConfig structure is used to configure a Client. It must not be
+// modified after having been passed to an SSH function.
type ClientConfig struct {
- // Rand provides the source of entropy for key exchange. If Rand is
- // nil, the cryptographic random reader in package crypto/rand will
- // be used.
- Rand io.Reader
+ // Config contains configuration that is shared between clients and
+ // servers.
+ Config
- // The username to authenticate.
+ // User contains the username to authenticate as.
User string
- // A slice of ClientAuth methods. Only the first instance
- // of a particular RFC 4252 method will be used during authentication.
- Auth []ClientAuth
+ // Auth contains possible authentication methods to use with the
+ // server. Only the first instance of a particular RFC 4252 method will
+ // be used during authentication.
+ Auth []AuthMethod
- // HostKeyChecker, if not nil, is called during the cryptographic
- // handshake to validate the server's host key. A nil HostKeyChecker
+ // HostKeyCallback, if not nil, is called during the cryptographic
+ // handshake to validate the server's host key. A nil HostKeyCallback
// implies that all host keys are accepted.
- HostKeyChecker HostKeyChecker
+ HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
- // Cryptographic-related configuration.
- Crypto CryptoConfig
-
- // The identification string that will be used for the connection.
- // If empty, a reasonable default is used.
+ // ClientVersion contains the version identification string that will
+ // be used for the connection. If empty, a reasonable default is used.
ClientVersion string
}
-
-func (c *ClientConfig) rand() io.Reader {
- if c.Rand == nil {
- return rand.Reader
- }
- return c.Rand
-}
-
-// Thread safe channel list.
-type chanList struct {
- // protects concurrent access to chans
- sync.Mutex
- // chans are indexed by the local id of the channel, clientChan.localId.
- // The PeersId value of messages received by ClientConn.mainLoop is
- // used to locate the right local clientChan in this slice.
- chans []*clientChan
-}
-
-// Allocate a new ClientChan with the next avail local id.
-func (c *chanList) newChan(p packetConn) *clientChan {
- c.Lock()
- defer c.Unlock()
- for i := range c.chans {
- if c.chans[i] == nil {
- ch := newClientChan(p, uint32(i))
- c.chans[i] = ch
- return ch
- }
- }
- i := len(c.chans)
- ch := newClientChan(p, uint32(i))
- c.chans = append(c.chans, ch)
- return ch
-}
-
-func (c *chanList) getChan(id uint32) (*clientChan, bool) {
- c.Lock()
- defer c.Unlock()
- if id >= uint32(len(c.chans)) {
- return nil, false
- }
- return c.chans[id], true
-}
-
-func (c *chanList) remove(id uint32) {
- c.Lock()
- defer c.Unlock()
- c.chans[id] = nil
-}
-
-func (c *chanList) closeAll() {
- c.Lock()
- defer c.Unlock()
-
- for _, ch := range c.chans {
- if ch == nil {
- continue
- }
- ch.Close()
- close(ch.msg)
- }
-}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 29be0ca..5b7aa30 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -5,16 +5,16 @@
package ssh
import (
+ "bytes"
"errors"
"fmt"
"io"
- "net"
)
-// authenticate authenticates with the remote server. See RFC 4252.
-func (c *ClientConn) authenticate() error {
+// clientAuthenticate authenticates with the remote server, trying config.Auth entries in order. See RFC 4252.
+func (c *connection) clientAuthenticate(config *ClientConfig) error {
// initiate user auth session
- if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
+ if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
packet, err := c.transport.readPacket()
@@ -22,14 +22,15 @@
return err
}
var serviceAccept serviceAcceptMsg
- if err := unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
+ if err := Unmarshal(packet, &serviceAccept); err != nil {
return err
}
+
// during the authentication phase the client first attempts the "none" method
// then any untried methods suggested by the server.
- tried, remain := make(map[string]bool), make(map[string]bool)
- for auth := ClientAuth(new(noneAuth)); auth != nil; {
- ok, methods, err := auth.auth(c.transport.sessionID, c.config.User, c.transport, c.config.rand())
+ tried := make(map[string]bool)
+ for auth := AuthMethod(new(noneAuth)); auth != nil; {
+ ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
if err != nil {
return err
}
@@ -38,45 +39,35 @@
return nil
}
tried[auth.method()] = true
- delete(remain, auth.method())
- for _, meth := range methods {
- if tried[meth] {
- // if we've tried meth already, skip it.
- continue
- }
- remain[meth] = true
- }
+ // Pick the first config.Auth entry whose method the server offers and that has not been tried yet.
auth = nil
- for _, a := range c.config.Auth {
- if remain[a.method()] {
- auth = a
- break
+ for i := 0; auth == nil && i < len(config.Auth); i++ { // stop at the first match, honouring config order
+ candidateMethod := config.Auth[i].method()
+ if tried[candidateMethod] {
+ continue
+ }
+ for _, meth := range methods {
+ if meth == candidateMethod {
+ auth = config.Auth[i]
+ break
+ }
}
}
}
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried))
}
-func keys(m map[string]bool) (s []string) {
- for k := range m {
- s = append(s, k)
+func keys(m map[string]bool) []string {
+ s := make([]string, 0, len(m))
+
+ for key := range m {
+ s = append(s, key)
}
- return
+ return s
}
-// HostKeyChecker represents a database of known server host keys.
-type HostKeyChecker interface {
- // Check is called during the handshake to check server's
- // public key for unexpected changes. The hostKey argument is
- // in SSH wire format. It can be parsed using
- // ssh.ParsePublicKey. The address before DNS resolution is
- // passed in the addr argument, so the key can also be checked
- // against the hostname.
- Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error
-}
-
-// A ClientAuth represents an instance of an RFC 4252 authentication method.
-type ClientAuth interface {
+// An AuthMethod represents an instance of an RFC 4252 authentication method.
+type AuthMethod interface {
// auth authenticates user over transport t.
// Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative
@@ -91,7 +82,7 @@
type noneAuth int
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
- if err := c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
+ if err := c.writePacket(Marshal(&userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: "none",
@@ -106,29 +97,31 @@
return "none"
}
-// "password" authentication, RFC 4252 Section 8.
-type passwordAuth struct {
- ClientPassword
-}
+// passwordCallback is an AuthMethod that fetches the password through
+// a function call, e.g. by prompting the user.
+type passwordCallback func() (password string, err error)
-func (p *passwordAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct {
- User string
+ User string `sshtype:"50"`
Service string
Method string
Reply bool
Password string
}
- pw, err := p.Password(user)
+ pw, err := cb()
+ // REVIEW NOTE: is there a need to support skipping a password attempt?
+ // The program may only find out that the user doesn't have a password
+ // when prompting.
if err != nil {
return false, nil, err
}
- if err := c.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{
+ if err := c.writePacket(Marshal(&passwordAuthMsg{
User: user,
Service: serviceSSH,
- Method: "password",
+ Method: cb.method(),
Reply: false,
Password: pw,
})); err != nil {
@@ -138,106 +131,93 @@
return handleAuthResponse(c)
}
-func (p *passwordAuth) method() string {
+func (cb passwordCallback) method() string {
return "password"
}
-// A ClientPassword implements access to a client's passwords.
-type ClientPassword interface {
- // Password returns the password to use for user.
- Password(user string) (password string, err error)
+// Password returns an AuthMethod using the given password.
+func Password(secret string) AuthMethod {
+ return passwordCallback(func() (string, error) { return secret, nil })
}
-// ClientAuthPassword returns a ClientAuth using password authentication.
-func ClientAuthPassword(impl ClientPassword) ClientAuth {
- return &passwordAuth{impl}
-}
-
-// ClientKeyring implements access to a client key ring.
-type ClientKeyring interface {
- // Key returns the i'th Publickey, or nil if no key exists at i.
- Key(i int) (key PublicKey, err error)
-
- // Sign returns a signature of the given data using the i'th key
- // and the supplied random source.
- Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
-}
-
-// "publickey" authentication, RFC 4252 Section 7.
-type publickeyAuth struct {
- ClientKeyring
+// PasswordCallback returns an AuthMethod that uses a callback for
+// fetching a password.
+func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
+ return passwordCallback(prompt)
}
type publickeyAuthMsg struct {
- User string
+ User string `sshtype:"50"`
Service string
Method string
// HasSig indicates to the receiver packet that the auth request is signed and
// should be used for authentication of the request.
HasSig bool
Algoname string
- Pubkey string
- // Sig is defined as []byte so marshal will exclude it during validateKey
+ PubKey []byte
+ // Sig is tagged with "rest" so Marshal will exclude it during
+ // validateKey
Sig []byte `ssh:"rest"`
}
-func (p *publickeyAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+// publicKeyCallback is an AuthMethod that runs a callback to obtain
+// the key pairs (signers) used for authentication.
+type publicKeyCallback func() ([]Signer, error)
+
+func (cb publicKeyCallback) method() string {
+ return "publickey"
+}
+
+func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
- var index int
- // a map of public keys to their index in the keyring
- validKeys := make(map[int]PublicKey)
- for {
- key, err := p.Key(index)
- if err != nil {
- return false, nil, err
- }
- if key == nil {
- // no more keys in the keyring
- break
- }
-
- if ok, err := p.validateKey(key, user, c); ok {
- validKeys[index] = key
+ signers, err := cb()
+ if err != nil {
+ return false, nil, err
+ }
+ var validKeys []Signer
+ for _, signer := range signers {
+ if ok, err := validateKey(signer.PublicKey(), user, c); ok {
+ validKeys = append(validKeys, signer)
} else {
if err != nil {
return false, nil, err
}
}
- index++
}
// methods that may continue if this auth is not successful.
var methods []string
- for i, key := range validKeys {
- pubkey := MarshalPublicKey(key)
- algoname := key.PublicKeyAlgo()
- data := buildDataSignedForAuth(session, userAuthRequestMsg{
+ for _, signer := range validKeys {
+ pub := signer.PublicKey()
+
+ pubKey := pub.Marshal()
+ sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
Service: serviceSSH,
- Method: p.method(),
- }, []byte(algoname), pubkey)
- sigBlob, err := p.Sign(i, rand, data)
+ Method: cb.method(),
+ }, []byte(pub.Type()), pubKey))
if err != nil {
return false, nil, err
}
+
// manually wrap the serialized signature in a string
- s := serializeSignature(key.PublicKeyAlgo(), sigBlob)
+ s := Marshal(sign)
sig := make([]byte, stringLength(len(s)))
marshalString(sig, s)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
- Method: p.method(),
+ Method: cb.method(),
HasSig: true,
- Algoname: algoname,
- Pubkey: string(pubkey),
+ Algoname: pub.Type(),
+ PubKey: pubKey,
Sig: sig,
}
- p := marshal(msgUserAuthRequest, msg)
+ p := Marshal(&msg)
if err := c.writePacket(p); err != nil {
return false, nil, err
}
@@ -252,28 +232,27 @@
return false, methods, nil
}
-// validateKey validates the key provided it is acceptable to the server.
-func (p *publickeyAuth) validateKey(key PublicKey, user string, c packetConn) (bool, error) {
- pubkey := MarshalPublicKey(key)
- algoname := key.PublicKeyAlgo()
+// validateKey asks the server whether the provided public key is acceptable.
+func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
+ pubKey := key.Marshal()
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
- Method: p.method(),
+ Method: "publickey",
HasSig: false,
- Algoname: algoname,
- Pubkey: string(pubkey),
+ Algoname: key.Type(),
+ PubKey: pubKey,
}
- if err := c.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
+ if err := c.writePacket(Marshal(&msg)); err != nil {
return false, err
}
- return p.confirmKeyAck(key, c)
+ return confirmKeyAck(key, c)
}
-func (p *publickeyAuth) confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
- pubkey := MarshalPublicKey(key)
- algoname := key.PublicKeyAlgo()
+func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
+ pubKey := key.Marshal()
+ algoname := key.Type()
for {
packet, err := c.readPacket()
@@ -284,30 +263,32 @@
case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user
case msgUserAuthPubKeyOk:
- msg := userAuthPubKeyOkMsg{}
- if err := unmarshal(&msg, packet, msgUserAuthPubKeyOk); err != nil {
+ var msg userAuthPubKeyOkMsg
+ if err := Unmarshal(packet, &msg); err != nil {
return false, err
}
- if msg.Algo != algoname || msg.PubKey != string(pubkey) {
+ if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) {
return false, nil
}
return true, nil
case msgUserAuthFailure:
return false, nil
default:
- return false, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+ return false, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
}
- panic("unreachable")
}
-func (p *publickeyAuth) method() string {
- return "publickey"
+// PublicKeys returns an AuthMethod that uses the given key
+// pairs.
+func PublicKeys(signers ...Signer) AuthMethod {
+ return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
}
-// ClientAuthKeyring returns a ClientAuth using public key authentication.
-func ClientAuthKeyring(impl ClientKeyring) ClientAuth {
- return &publickeyAuth{impl}
+// PublicKeysCallback returns an AuthMethod that runs the given
+// function to obtain a list of key pairs.
+func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
+ return publicKeyCallback(getSigners)
}
// handleAuthResponse returns whether the preceding authentication request succeeded
@@ -324,8 +305,8 @@
case msgUserAuthBanner:
// TODO: add callback to present the banner to the user
case msgUserAuthFailure:
- msg := userAuthFailureMsg{}
- if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil {
+ var msg userAuthFailureMsg
+ if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
}
return false, msg.Methods, nil
@@ -334,98 +315,40 @@
case msgDisconnect:
return false, nil, io.EOF
default:
- return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+ return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
}
- panic("unreachable")
}
-// ClientAuthAgent returns a ClientAuth using public key authentication via
-// an agent.
-func ClientAuthAgent(agent *AgentClient) ClientAuth {
- return ClientAuthKeyring(&agentKeyring{agent: agent})
+// KeyboardInteractiveChallenge should print questions, optionally
+// disabling echoing (e.g. for passwords), and return all the answers.
+// The challenge may be called multiple times in a single session. After
+// successful authentication, the server may send a challenge with no
+// questions, for which the user and instruction messages should be
+// printed. RFC 4256 section 3.3 details how the UI should behave for
+// both CLI and GUI environments.
+type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
+
+// KeyboardInteractive returns an AuthMethod using a prompt/response
+// sequence controlled by the server.
+func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
+ return challenge
}
-// agentKeyring implements ClientKeyring.
-type agentKeyring struct {
- agent *AgentClient
- keys []*AgentKey
-}
-
-func (kr *agentKeyring) Key(i int) (key PublicKey, err error) {
- if kr.keys == nil {
- if kr.keys, err = kr.agent.RequestIdentities(); err != nil {
- return
- }
- }
- if i >= len(kr.keys) {
- return
- }
- return kr.keys[i].Key()
-}
-
-func (kr *agentKeyring) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
- var key PublicKey
- if key, err = kr.Key(i); err != nil {
- return
- }
- if key == nil {
- return nil, errors.New("ssh: key index out of range")
- }
- if sig, err = kr.agent.SignRequest(key, data); err != nil {
- return
- }
-
- // Unmarshal the signature.
-
- var ok bool
- if _, sig, ok = parseString(sig); !ok {
- return nil, errors.New("ssh: malformed signature response from agent")
- }
- if sig, _, ok = parseString(sig); !ok {
- return nil, errors.New("ssh: malformed signature response from agent")
- }
- return sig, nil
-}
-
-// ClientKeyboardInteractive should prompt the user for the given
-// questions.
-type ClientKeyboardInteractive interface {
- // Challenge should print the questions, optionally disabling
- // echoing (eg. for passwords), and return all the answers.
- // Challenge may be called multiple times in a single
- // session. After successful authentication, the server may
- // send a challenge with no questions, for which the user and
- // instruction messages should be printed. RFC 4256 section
- // 3.3 details how the UI should behave for both CLI and
- // GUI environments.
- Challenge(user, instruction string, questions []string, echos []bool) ([]string, error)
-}
-
-// ClientAuthKeyboardInteractive returns a ClientAuth using a
-// prompt/response sequence controlled by the server.
-func ClientAuthKeyboardInteractive(impl ClientKeyboardInteractive) ClientAuth {
- return &keyboardInteractiveAuth{impl}
-}
-
-type keyboardInteractiveAuth struct {
- ClientKeyboardInteractive
-}
-
-func (k *keyboardInteractiveAuth) method() string {
+func (cb KeyboardInteractiveChallenge) method() string {
return "keyboard-interactive"
}
-func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type initiateMsg struct {
- User string
+ User string `sshtype:"50"`
Service string
Method string
Language string
Submethods string
}
- if err := c.writePacket(marshal(msgUserAuthRequest, initiateMsg{
+ if err := c.writePacket(Marshal(&initiateMsg{
User: user,
Service: serviceSSH,
Method: "keyboard-interactive",
@@ -448,18 +371,18 @@
// OK
case msgUserAuthFailure:
var msg userAuthFailureMsg
- if err := unmarshal(&msg, packet, msgUserAuthFailure); err != nil {
+ if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
}
return false, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
default:
- return false, nil, UnexpectedMessageError{msgUserAuthInfoRequest, packet[0]}
+ return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
}
var msg userAuthInfoRequestMsg
- if err := unmarshal(&msg, packet, packet[0]); err != nil {
+ if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
}
@@ -478,10 +401,10 @@
}
if len(rest) != 0 {
- return false, nil, fmt.Errorf("ssh: junk following message %q", rest)
+ return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
}
- answers, err := k.Challenge(msg.User, msg.Instruction, prompts, echos)
+ answers, err := cb(msg.User, msg.Instruction, prompts, echos)
if err != nil {
return false, nil, err
}
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index f2fc9c6..3173255 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -6,363 +6,317 @@
import (
"bytes"
- "crypto/dsa"
- "io"
- "io/ioutil"
- "math/big"
+ "crypto/rand"
+ "errors"
+ "fmt"
"strings"
"testing"
-
- _ "crypto/sha1"
)
-// private key for mock server
-const testServerPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
-MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
-70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
-9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
-tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
-s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
-qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
-+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
-riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
-D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
-atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
-b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
-ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
-MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
-KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
-e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
-D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
-3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
-orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
-64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
-XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
-QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
-/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
-I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
-gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
-NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
------END RSA PRIVATE KEY-----`
-
-const testClientPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
-MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
-r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
-tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
-nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
-2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
-y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
-rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
------END RSA PRIVATE KEY-----`
-
-// keychain implements the ClientKeyring interface
-type keychain struct {
- keys []Signer
-}
-
-func (k *keychain) Key(i int) (PublicKey, error) {
- if i < 0 || i >= len(k.keys) {
- return nil, nil
- }
-
- return k.keys[i].PublicKey(), nil
-}
-
-func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
- return k.keys[i].Sign(rand, data)
-}
-
-func (k *keychain) add(key Signer) {
- k.keys = append(k.keys, key)
-}
-
-func (k *keychain) loadPEM(file string) error {
- buf, err := ioutil.ReadFile(file)
- if err != nil {
- return err
- }
- key, err := ParsePrivateKey(buf)
- if err != nil {
- return err
- }
- k.add(key)
- return nil
-}
-
-// password implements the ClientPassword interface
-type password string
-
-func (p password) Password(user string) (string, error) {
- return string(p), nil
-}
-
type keyboardInteractive map[string]string
-func (cr *keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
+func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) {
var answers []string
for _, q := range questions {
- answers = append(answers, (*cr)[q])
+ answers = append(answers, cr[q])
}
return answers, nil
}
// reused internally by tests
-var (
- rsaKey Signer
- dsaKey Signer
- clientKeychain = new(keychain)
- clientPassword = password("tiger")
- serverConfig = &ServerConfig{
- PasswordCallback: func(conn *ServerConn, user, pass string) bool {
- return user == "testuser" && pass == string(clientPassword)
+var clientPassword = "tiger"
+
+// tryAuth runs a handshake with a given config against an SSH server
+// with config serverConfig
+func tryAuth(t *testing.T, config *ClientConfig) error {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ certChecker := CertChecker{
+ IsAuthority: func(k PublicKey) bool {
+ return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
- PublicKeyCallback: func(conn *ServerConn, user, algo string, pubkey []byte) bool {
- key, _ := clientKeychain.Key(0)
- expected := MarshalPublicKey(key)
- algoname := key.PublicKeyAlgo()
- return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
+ UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+ if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User())
},
- KeyboardInteractiveCallback: func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool {
- ans, err := client.Challenge("user",
+ IsRevoked: func(c *Certificate) bool {
+ return c.Serial == 666
+ },
+ }
+
+ serverConfig := &ServerConfig{
+ PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
+ if conn.User() == "testuser" && string(pass) == clientPassword {
+ return nil, nil
+ }
+ return nil, errors.New("password auth failed")
+ },
+ PublicKeyCallback: certChecker.Authenticate,
+ KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) {
+ ans, err := challenge("user",
"instruction",
[]string{"question1", "question2"},
[]bool{true, true})
if err != nil {
- return false
+ return nil, err
}
- ok := user == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
- client.Challenge("user", "motd", nil, nil)
- return ok
+ ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2"
+ if ok {
+ challenge("user", "motd", nil, nil)
+ return nil, nil
+ }
+ return nil, errors.New("keyboard-interactive failed")
+ },
+ AuthLogCallback: func(conn ConnMetadata, method string, err error) {
+ t.Logf("user %q, method %q: %v", conn.User(), method, err)
},
}
-)
+ serverConfig.AddHostKey(testSigners["rsa"])
-func init() {
- var err error
- rsaKey, err = ParsePrivateKey([]byte(testServerPrivateKey))
- if err != nil {
- panic("unable to set private key: " + err.Error())
- }
- rawDSAKey := new(dsa.PrivateKey)
-
- // taken from crypto/dsa/dsa_test.go
- rawDSAKey.P, _ = new(big.Int).SetString("A9B5B793FB4785793D246BAE77E8FF63CA52F442DA763C440259919FE1BC1D6065A9350637A04F75A2F039401D49F08E066C4D275A5A65DA5684BC563C14289D7AB8A67163BFBF79D85972619AD2CFF55AB0EE77A9002B0EF96293BDD0F42685EBB2C66C327079F6C98000FBCB79AACDE1BC6F9D5C7B1A97E3D9D54ED7951FEF", 16)
- rawDSAKey.Q, _ = new(big.Int).SetString("E1D3391245933D68A0714ED34BBCB7A1F422B9C1", 16)
- rawDSAKey.G, _ = new(big.Int).SetString("634364FC25248933D01D1993ECABD0657CC0CB2CEED7ED2E3E8AECDFCDC4A25C3B15E9E3B163ACA2984B5539181F3EFF1A5E8903D71D5B95DA4F27202B77D2C44B430BB53741A8D59A8F86887525C9F2A6A5980A195EAA7F2FF910064301DEF89D3AA213E1FAC7768D89365318E370AF54A112EFBA9246D9158386BA1B4EEFDA", 16)
- rawDSAKey.Y, _ = new(big.Int).SetString("32969E5780CFE1C849A1C276D7AEB4F38A23B591739AA2FE197349AEEBD31366AEE5EB7E6C6DDB7C57D02432B30DB5AA66D9884299FAA72568944E4EEDC92EA3FBC6F39F53412FBCC563208F7C15B737AC8910DBC2D9C9B8C001E72FDC40EB694AB1F06A5A2DBD18D9E36C66F31F566742F11EC0A52E9F7B89355C02FB5D32D2", 16)
- rawDSAKey.X, _ = new(big.Int).SetString("5078D4D29795CBE76D3AACFE48C9AF0BCDBEE91A", 16)
-
- dsaKey, err = NewSignerFromKey(rawDSAKey)
- if err != nil {
- panic("NewSignerFromKey: " + err.Error())
- }
- clientKeychain.add(rsaKey)
- serverConfig.AddHostKey(rsaKey)
-}
-
-// newMockAuthServer creates a new Server bound to
-// the loopback interface. The server exits after
-// processing one handshake.
-func newMockAuthServer(t *testing.T) string {
- l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
- if err != nil {
- t.Fatalf("unable to newMockAuthServer: %s", err)
- }
- go func() {
- defer l.Close()
- c, err := l.Accept()
- if err != nil {
- t.Errorf("Unable to accept incoming connection: %v", err)
- return
- }
- if err := c.Handshake(); err != nil {
- // not Errorf because this is expected to
- // fail for some tests.
- t.Logf("Handshaking error: %v", err)
- return
- }
- defer c.Close()
- }()
- return l.Addr().String()
+ go newServer(c1, serverConfig)
+ _, _, _, err = NewClientConn(c2, "", config)
+ return err
}
func TestClientAuthPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyring(clientKeychain),
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err != nil {
+ if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
- c.Close()
}
-func TestClientAuthPassword(t *testing.T) {
+func TestAuthMethodPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthPassword(clientPassword),
+ Auth: []AuthMethod{
+ Password(clientPassword),
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err != nil {
+ if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
- c.Close()
}
-func TestClientAuthWrongPassword(t *testing.T) {
- wrongPw := password("wrong")
+func TestAuthMethodWrongPassword(t *testing.T) {
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthPassword(wrongPw),
- ClientAuthKeyring(clientKeychain),
+ Auth: []AuthMethod{
+ Password("wrong"),
+ PublicKeys(testSigners["rsa"]),
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err != nil {
+ if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
- c.Close()
}
-func TestClientAuthKeyboardInteractive(t *testing.T) {
+func TestAuthMethodKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "answer2",
})
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyboardInteractive(&answers),
+ Auth: []AuthMethod{
+ KeyboardInteractive(answers.Challenge),
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err != nil {
+ if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
- c.Close()
}
-func TestClientAuthWrongKeyboardInteractive(t *testing.T) {
+func TestAuthMethodWrongKeyboardInteractive(t *testing.T) {
answers := keyboardInteractive(map[string]string{
"question1": "answer1",
"question2": "WRONG",
})
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyboardInteractive(&answers),
+ Auth: []AuthMethod{
+ KeyboardInteractive(answers.Challenge),
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err == nil {
- c.Close()
+ if err := tryAuth(t, config); err == nil {
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive")
}
}
// the mock server will only authenticate ssh-rsa keys
-func TestClientAuthInvalidPublicKey(t *testing.T) {
- kc := new(keychain)
-
- kc.add(dsaKey)
+func TestAuthMethodInvalidPublicKey(t *testing.T) {
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyring(kc),
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["dsa"]),
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err == nil {
- c.Close()
+ if err := tryAuth(t, config); err == nil {
t.Fatalf("dsa private key should not have authenticated with rsa public key")
}
}
// the client should authenticate with the second key
-func TestClientAuthRSAandDSA(t *testing.T) {
- kc := new(keychain)
- kc.add(dsaKey)
- kc.add(rsaKey)
+func TestAuthMethodRSAandDSA(t *testing.T) {
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyring(kc),
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["dsa"], testSigners["rsa"]),
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err != nil {
+ if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err)
}
- c.Close()
}
func TestClientHMAC(t *testing.T) {
- kc := new(keychain)
- kc.add(rsaKey)
- for _, mac := range DefaultMACOrder {
+ for _, mac := range supportedMACs {
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyring(kc),
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
},
- Crypto: CryptoConfig{
+ Config: Config{
MACs: []string{mac},
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err != nil {
+ if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
}
- c.Close()
}
}
// issue 4285.
func TestClientUnsupportedCipher(t *testing.T) {
- kc := new(keychain)
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyring(kc),
+ Auth: []AuthMethod{
+ PublicKeys(),
},
- Crypto: CryptoConfig{
+ Config: Config{
Ciphers: []string{"aes128-cbc"}, // not currently supported
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err == nil {
+ if err := tryAuth(t, config); err == nil {
t.Errorf("expected no ciphers in common")
- c.Close()
}
}
func TestClientUnsupportedKex(t *testing.T) {
- kc := new(keychain)
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthKeyring(kc),
+ Auth: []AuthMethod{
+ PublicKeys(),
},
- Crypto: CryptoConfig{
+ Config: Config{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
},
}
- c, err := Dial("tcp", newMockAuthServer(t), config)
- if err == nil || !strings.Contains(err.Error(), "no common algorithms") {
+ if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "no common algorithms") {
t.Errorf("got %v, expected 'no common algorithms'", err)
}
- if c != nil {
- c.Close()
+}
+
+func TestClientLoginCert(t *testing.T) {
+ cert := &Certificate{
+ Key: testPublicKeys["rsa"],
+ ValidBefore: CertTimeInfinity,
+ CertType: UserCert,
+ }
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ certSigner, err := NewCertSigner(cert, testSigners["rsa"])
+ if err != nil {
+ t.Fatalf("NewCertSigner: %v", err)
+ }
+
+ clientConfig := &ClientConfig{
+ User: "user",
+ }
+ clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
+
+ t.Log("should succeed")
+ if err := tryAuth(t, clientConfig); err != nil {
+ t.Errorf("cert login failed: %v", err)
+ }
+
+ t.Log("corrupted signature")
+ cert.Signature.Blob[0]++
+ if err := tryAuth(t, clientConfig); err == nil {
+ t.Errorf("cert login passed with corrupted sig")
+ }
+
+ t.Log("revoked")
+ cert.Serial = 666
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ if err := tryAuth(t, clientConfig); err == nil {
+ t.Errorf("revoked cert login succeeded")
+ }
+ cert.Serial = 1
+
+ t.Log("sign with wrong key")
+ cert.SignCert(rand.Reader, testSigners["dsa"])
+ if err := tryAuth(t, clientConfig); err == nil {
+ t.Errorf("cert login passed with non-authoritive key")
+ }
+
+ t.Log("host cert")
+ cert.CertType = HostCert
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ if err := tryAuth(t, clientConfig); err == nil {
+ t.Errorf("cert login passed with wrong type")
+ }
+ cert.CertType = UserCert
+
+ t.Log("principal specified")
+ cert.ValidPrincipals = []string{"user"}
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ if err := tryAuth(t, clientConfig); err != nil {
+ t.Errorf("cert login failed: %v", err)
+ }
+
+ t.Log("wrong principal specified")
+ cert.ValidPrincipals = []string{"fred"}
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ if err := tryAuth(t, clientConfig); err == nil {
+ t.Errorf("cert login passed with wrong principal")
+ }
+ cert.ValidPrincipals = nil
+
+ t.Log("added critical option")
+ cert.CriticalOptions = map[string]string{"root-access": "yes"}
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ if err := tryAuth(t, clientConfig); err == nil {
+ t.Errorf("cert login passed with unrecognized critical option")
+ }
+
+ t.Log("allowed source address")
+ cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"}
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ if err := tryAuth(t, clientConfig); err != nil {
+ t.Errorf("cert login with source-address failed: %v", err)
+ }
+
+ t.Log("disallowed source address")
+ cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"}
+ cert.SignCert(rand.Reader, testSigners["ecdsa"])
+ if err := tryAuth(t, clientConfig); err == nil {
+ t.Errorf("cert login with source-address succeeded")
}
}
diff --git a/ssh/client_test.go b/ssh/client_test.go
index f6c11b9..1fe790c 100644
--- a/ssh/client_test.go
+++ b/ssh/client_test.go
@@ -1,3 +1,7 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
package ssh
import (
@@ -7,6 +11,7 @@
func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
clientConn, serverConn := net.Pipe()
+ defer clientConn.Close()
receivedVersion := make(chan string, 1)
go func() {
version, err := readVersion(serverConn)
@@ -17,7 +22,7 @@
}
serverConn.Close()
}()
- Client(clientConn, config)
+ NewClientConn(clientConn, "", config)
actual := <-receivedVersion
if actual != expected {
t.Fatalf("got %s; want %s", actual, expected)
diff --git a/ssh/common.go b/ssh/common.go
index 4870e56..2fd7fd9 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -6,7 +6,9 @@
import (
"crypto"
+ "crypto/rand"
"fmt"
+ "io"
"sync"
_ "crypto/sha1"
@@ -21,16 +23,39 @@
serviceSSH = "ssh-connection"
)
+// supportedCiphers specifies the supported ciphers in preference order.
+var supportedCiphers = []string{
+ "aes128-ctr", "aes192-ctr", "aes256-ctr",
+ "aes128-gcm@openssh.com",
+ "arcfour256", "arcfour128",
+}
+
+// supportedKexAlgos specifies the supported key-exchange algorithms in
+// preference order.
var supportedKexAlgos = []string{
+ // P384 and P521 are not constant-time yet, but since we don't
+ // reuse ephemeral keys, using them for ECDH should be OK.
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA1, kexAlgoDH1SHA1,
}
+// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
+// of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{
+ CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
+ CertAlgoECDSA384v01, CertAlgoECDSA521v01,
+
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSA, KeyAlgoDSA,
}
+// supportedMACs specifies a default set of MAC algorithms in preference order.
+// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
+// because they have reached the end of their useful life.
+var supportedMACs = []string{
+ "hmac-sha1", "hmac-sha1-96",
+}
+
var supportedCompressions = []string{compressionNone}
// hashFuncs keeps the mapping of supported algorithms to their respective
@@ -48,23 +73,15 @@
CertAlgoECDSA521v01: crypto.SHA512,
}
-// UnexpectedMessageError results when the SSH message that we received didn't
+// unexpectedMessageError returns the error used when the SSH message we received didn't
// match what we wanted.
+// The resulting message reports the received type first, then the expected one.
-type UnexpectedMessageError struct {
- expected, got uint8
+func unexpectedMessageError(expected, got uint8) error {
+ return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected)
}
-func (u UnexpectedMessageError) Error() string {
- return fmt.Sprintf("ssh: unexpected message type %d (expected %d)", u.got, u.expected)
-}
-
-// ParseError results from a malformed SSH message.
-type ParseError struct {
- msgType uint8
-}
-
-func (p ParseError) Error() string {
- return fmt.Sprintf("ssh: parse error in message type %d", p.msgType)
+// parseError returns the error used when a received SSH message of type tag
+// is malformed.
+func parseError(tag uint8) error {
+ return fmt.Errorf("ssh: parse error in message type %d", tag)
}
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
@@ -90,15 +107,17 @@
return
}
+// directionAlgorithms holds the algorithms negotiated for one direction of
+// the connection.
+type directionAlgorithms struct {
+ Cipher string
+ MAC string
+ Compression string
+}
+
+// algorithms is the outcome of algorithm negotiation, as produced by
+// findAgreedAlgorithms.
type algorithms struct {
- kex string
- hostKey string
- wCipher string
- rCipher string
- rMAC string
- wMAC string
- rCompression string
- wCompression string
+ kex string
+ hostKey string
+ w directionAlgorithms // negotiated from the *ClientServer name-lists
+ r directionAlgorithms // negotiated from the *ServerClient name-lists
}
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
@@ -114,32 +133,32 @@
return
}
- result.wCipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+ result.w.Cipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if !ok {
return
}
- result.rCipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+ result.r.Cipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if !ok {
return
}
- result.wMAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+ result.w.MAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if !ok {
return
}
- result.rMAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+ result.r.MAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if !ok {
return
}
- result.wCompression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+ result.w.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if !ok {
return
}
- result.rCompression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+ result.r.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if !ok {
return
}
@@ -147,133 +166,87 @@
return result
}
-// Cryptographic configuration common to both ServerConfig and ClientConfig.
-type CryptoConfig struct {
+// minRekeyThreshold is the smallest RekeyThreshold accepted. With a smaller
+// threshold we could not make any progress sending data before the next
+// rekey was triggered.
+const minRekeyThreshold uint64 = 256
+
+// Config contains configuration data common to both ServerConfig and
+// ClientConfig.
+type Config struct {
+ // Rand provides the source of entropy for cryptographic
+ // primitives. If Rand is nil, the cryptographic random reader
+ // in package crypto/rand will be used.
+ Rand io.Reader
+
+ // The maximum number of bytes sent or received after which a
+ // new key is negotiated. SetDefaults raises values below 256 to
+ // 256. If unspecified (zero), 1 gigabyte (1<<30 bytes) is used.
+ RekeyThreshold uint64
+
// The allowed key exchanges algorithms. If unspecified then a
// default set of algorithms is used.
KeyExchanges []string
- // The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
- // used.
+ // The allowed cipher algorithms. If unspecified (nil) then a sensible
+ // default is used.
Ciphers []string
- // The allowed MAC algorithms. If unspecified then DefaultMACOrder is used.
+ // The allowed MAC algorithms. If unspecified (nil) then a sensible default
+ // is used.
MACs []string
}
-func (c *CryptoConfig) ciphers() []string {
+// SetDefaults sets sensible values for unset fields in config. This is
+// exported for testing: Configs passed to SSH functions are copied and have
+// default values set automatically.
+// Unset algorithm lists are pointed at the package-level default slices
+// themselves (not copies), so callers must not modify them in place.
+func (c *Config) SetDefaults() {
+ if c.Rand == nil {
+ c.Rand = rand.Reader
+ }
if c.Ciphers == nil {
- return DefaultCipherOrder
+ c.Ciphers = supportedCiphers
}
- return c.Ciphers
-}
-func (c *CryptoConfig) kexes() []string {
if c.KeyExchanges == nil {
- return defaultKeyExchangeOrder
+ c.KeyExchanges = supportedKexAlgos
}
- return c.KeyExchanges
-}
-func (c *CryptoConfig) macs() []string {
if c.MACs == nil {
- return DefaultMACOrder
+ c.MACs = supportedMACs
}
- return c.MACs
-}
-// serialize a signed slice according to RFC 4254 6.6. The name should
-// be a key type name, rather than a cert type name.
-func serializeSignature(name string, sig []byte) []byte {
- length := stringLength(len(name))
- length += stringLength(len(sig))
-
- ret := make([]byte, length)
- r := marshalString(ret, []byte(name))
- r = marshalString(r, sig)
-
- return ret
-}
-
-// MarshalPublicKey serializes a supported key or certificate for use
-// by the SSH wire protocol. It can be used for comparison with the
-// pubkey argument of ServerConfig's PublicKeyCallback as well as for
-// generating an authorized_keys or host_keys file.
-func MarshalPublicKey(key PublicKey) []byte {
- // See also RFC 4253 6.6.
- algoname := key.PublicKeyAlgo()
- blob := key.Marshal()
-
- length := stringLength(len(algoname))
- length += len(blob)
- ret := make([]byte, length)
- r := marshalString(ret, []byte(algoname))
- copy(r, blob)
- return ret
-}
-
-// pubAlgoToPrivAlgo returns the private key algorithm format name that
-// corresponds to a given public key algorithm format name. For most
-// public keys, the private key algorithm name is the same. For some
-// situations, such as openssh certificates, the private key algorithm and
-// public key algorithm names differ. This accounts for those situations.
-func pubAlgoToPrivAlgo(pubAlgo string) string {
- switch pubAlgo {
- case CertAlgoRSAv01:
- return KeyAlgoRSA
- case CertAlgoDSAv01:
- return KeyAlgoDSA
- case CertAlgoECDSA256v01:
- return KeyAlgoECDSA256
- case CertAlgoECDSA384v01:
- return KeyAlgoECDSA384
- case CertAlgoECDSA521v01:
- return KeyAlgoECDSA521
+ if c.RekeyThreshold == 0 {
+ // RFC 4253, section 9 suggests rekeying after 1G.
+ c.RekeyThreshold = 1 << 30
}
- return pubAlgo
+ // Silently clamp tiny thresholds; see minRekeyThreshold.
+ if c.RekeyThreshold < minRekeyThreshold {
+ c.RekeyThreshold = minRekeyThreshold
+ }
}
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
- user := []byte(req.User)
- service := []byte(req.Service)
- method := []byte(req.Method)
-
- length := stringLength(len(sessionId))
- length += 1
- length += stringLength(len(user))
- length += stringLength(len(service))
- length += stringLength(len(method))
- length += 1
- length += stringLength(len(algo))
- length += stringLength(len(pubKey))
-
- ret := make([]byte, length)
- r := marshalString(ret, sessionId)
- r[0] = msgUserAuthRequest
- r = r[1:]
- r = marshalString(r, user)
- r = marshalString(r, service)
- r = marshalString(r, method)
- r[0] = 1
- r = r[1:]
- r = marshalString(r, algo)
- r = marshalString(r, pubKey)
- return ret
-}
-
-// safeString sanitises s according to RFC 4251, section 9.2.
-// All control characters except tab, carriage return and newline are
-// replaced by 0x20.
-func safeString(s string) string {
- out := []byte(s)
- for i, c := range out {
- if c < 0x20 && c != 0xd && c != 0xa && c != 0x9 {
- out[i] = 0x20
- }
+ // The fields below follow the order of the signed blob in RFC 4252,
+ // section 7. This presumably relies on Marshal encoding struct fields in
+ // declaration order — TODO confirm against Marshal.
+ data := struct {
+ Session []byte
+ Type byte
+ User string
+ Service string
+ Method string
+ Sign bool
+ Algo []byte
+ PubKey []byte
+ }{
+ sessionId,
+ msgUserAuthRequest,
+ req.User,
+ req.Service,
+ req.Method,
+ true, // the boolean that says a signature is present
+ algo,
+ pubKey,
}
- return string(out)
+ return Marshal(data)
}
func appendU16(buf []byte, n uint16) []byte {
@@ -284,6 +257,12 @@
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
}
+// appendU64 appends n to buf as an 8-byte big-endian integer and returns the
+// extended slice.
+func appendU64(buf []byte, n uint64) []byte {
+	// Encode into a scratch array first so buf grows with a single append.
+	var be [8]byte
+	for i := range be {
+		be[i] = byte(n >> uint(56-8*i))
+	}
+	return append(buf, be[:]...)
+}
+
+// appendInt appends n via appendU32 after converting it to uint32, so values
+// outside the uint32 range wrap rather than being rejected.
func appendInt(buf []byte, n int) []byte {
return appendU32(buf, uint32(n))
}
@@ -296,11 +275,9 @@
+// appendBool appends b as a single byte (1 for true, 0 for false) and returns
+// the extended slice.
func appendBool(buf []byte, b bool) []byte {
if b {
- buf = append(buf, 1)
- } else {
- buf = append(buf, 0)
+ return append(buf, 1)
}
- return buf
+ return append(buf, 0)
}
// newCond is a helper to hide the fact that there is no usable zero
@@ -311,7 +288,9 @@
// wishing to write to a channel.
type window struct {
*sync.Cond
- win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
+ win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
+ // writeWaiters counts reserve calls that have entered the wait loop and
+ // not yet left it; waitWriterBlocked watches it.
+ writeWaiters int
+ // closed is set by close; reserve then stops waiting and returns io.EOF.
+ closed bool
}
// add adds win to the amount of window available
@@ -335,18 +314,44 @@
return true
}
+// close marks the window as closed and wakes every goroutine waiting in
+// reserve; from then on reserve returns io.EOF instead of blocking.
+func (w *window) close() {
+	w.L.Lock()
+	defer w.L.Unlock()
+	w.closed = true
+	w.Broadcast()
+}
+
// reserve reserves win from the available window capacity.
// If no capacity remains, reserve will block. reserve may
// return less than requested.
+// Once the window has been closed, reserve no longer blocks: it still takes
+// whatever capacity is left (possibly zero) and returns io.EOF.
-func (w *window) reserve(win uint32) uint32 {
+func (w *window) reserve(win uint32) (uint32, error) {
+ var err error
w.L.Lock()
- for w.win == 0 {
+ w.writeWaiters++
+ // Wake waitWriterBlocked, which waits for writeWaiters to become non-zero.
+ w.Broadcast()
+ for w.win == 0 && !w.closed {
w.Wait()
}
+ w.writeWaiters--
if w.win < win {
win = w.win
}
w.win -= win
+ if w.closed {
+ err = io.EOF
+ }
w.L.Unlock()
- return win
+ return win, err
+}
+
+// waitWriterBlocked returns once at least one goroutine has entered reserve
+// and is still waiting for capacity. It is used in tests only.
+func (w *window) waitWriterBlocked() {
+	w.L.Lock()
+	defer w.L.Unlock()
+	for w.writeWaiters == 0 {
+		w.Wait()
+	}
}
diff --git a/ssh/connection.go b/ssh/connection.go
new file mode 100644
index 0000000..93551e2
--- /dev/null
+++ b/ssh/connection.go
@@ -0,0 +1,144 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "fmt"
+ "net"
+)
+
+// OpenChannelError is returned if the other side rejects an
+// OpenChannel request.
+type OpenChannelError struct {
+ // Reason is the rejection reason reported by the other side.
+ Reason RejectionReason
+ // Message is the free-form text that accompanied Reason.
+ Message string
+}
+
+// Error formats the rejection as "ssh: rejected: <reason> (<message>)".
+// It has a pointer receiver, so only *OpenChannelError implements error.
+func (e *OpenChannelError) Error() string {
+ return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
+}
+
+// ConnMetadata holds metadata for the connection.
+type ConnMetadata interface {
+ // User returns the user ID for this connection.
+ // It is empty if no authentication is used.
+ User() string
+
+ // SessionID returns the session hash, also denoted by H.
+ SessionID() []byte
+
+ // ClientVersion returns the client's version string as hashed
+ // into the session ID.
+ ClientVersion() []byte
+
+ // ServerVersion returns the server's version string as hashed
+ // into the session ID.
+ ServerVersion() []byte
+
+ // RemoteAddr returns the remote address for this connection.
+ RemoteAddr() net.Addr
+
+ // LocalAddr returns the local address for this connection.
+ LocalAddr() net.Addr
+}
+
+// Conn represents an SSH connection for both server and client roles.
+// Conn is the basis for implementing an application layer, such
+// as ClientConn, which implements the traditional shell access for
+// clients.
+type Conn interface {
+ ConnMetadata
+
+ // SendRequest sends a global request, and returns the
+ // reply. If wantReply is true, it returns the response status
+ // and payload. See also RFC4254, section 4.
+ SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
+
+ // OpenChannel tries to open a channel. If the request is
+ // rejected, it returns *OpenChannelError. On success it returns
+ // the SSH Channel and a Go channel for incoming, out-of-band
+ // requests. The Go channel must be serviced, or the
+ // connection will hang; DiscardRequests can be used when the
+ // requests are not needed.
+ OpenChannel(name string, data []byte) (Channel, <-chan *Request, error)
+
+ // Close closes the underlying network connection.
+ Close() error
+
+ // Wait blocks until the connection has shut down, and returns the
+ // error causing the shutdown.
+ Wait() error
+
+ // TODO(hanwen): consider exposing:
+ // RequestKeyChange
+ // Disconnect
+}
+
+// DiscardRequests drains in until it is closed, refusing (Reply(false, nil))
+// every request that asks for a reply and silently dropping the rest.
+func DiscardRequests(in <-chan *Request) {
+	for req := range in {
+		if !req.WantReply {
+			continue
+		}
+		req.Reply(false, nil)
+	}
+}
+
+// A connection represents an incoming connection.
+// It bundles the handshake transport, the connection metadata (sshConn) and
+// the multiplexer implementing the connection protocol (*mux).
+// NOTE(review): nothing visible here is server-specific — confirm whether
+// "incoming" is accurate or clients use this type too.
+type connection struct {
+ transport *handshakeTransport
+ sshConn
+
+ // The connection protocol.
+ *mux
+}
+
+// Close closes the underlying net.Conn directly; it does not go through
+// transport.
+func (c *connection) Close() error {
+ return c.sshConn.conn.Close()
+}
+
+// sshConn provides net.Conn metadata, but disallows direct reads and
+// writes: it exposes no Read or Write methods of its own.
+type sshConn struct {
+ // conn is the underlying network connection.
+ conn net.Conn
+
+ // Values reported through the ConnMetadata accessors; the byte slices
+ // are handed out as copies.
+ user string
+ sessionID []byte
+ clientVersion []byte
+ serverVersion []byte
+}
+
+// dup returns a freshly allocated copy of src (never nil), so that callers
+// can modify the result without affecting the original.
+func dup(src []byte) []byte {
+	return append(make([]byte, 0, len(src)), src...)
+}
+
+// User returns the user name stored for this connection.
+func (c *sshConn) User() string {
+	return c.user
+}
+
+// SessionID returns a copy of the session ID.
+func (c *sshConn) SessionID() []byte {
+	return dup(c.sessionID)
+}
+
+// ClientVersion returns a copy of the client's version string.
+func (c *sshConn) ClientVersion() []byte {
+	return dup(c.clientVersion)
+}
+
+// ServerVersion returns a copy of the server's version string.
+func (c *sshConn) ServerVersion() []byte {
+	return dup(c.serverVersion)
+}
+
+// RemoteAddr returns the remote address of the underlying net.Conn.
+func (c *sshConn) RemoteAddr() net.Addr {
+	return c.conn.RemoteAddr()
+}
+
+// LocalAddr returns the local address of the underlying net.Conn.
+func (c *sshConn) LocalAddr() net.Addr {
+	return c.conn.LocalAddr()
+}
+
+// Close closes the underlying net.Conn.
+func (c *sshConn) Close() error {
+	return c.conn.Close()
+}
diff --git a/ssh/doc.go b/ssh/doc.go
index 22ff338..d4d16f0 100644
--- a/ssh/doc.go
+++ b/ssh/doc.go
@@ -13,7 +13,6 @@
References:
[PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys
- [PROTOCOL.agent]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
*/
package ssh
diff --git a/ssh/example_test.go b/ssh/example_test.go
index a88a677..d9d6a54 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -9,17 +9,23 @@
"fmt"
"io/ioutil"
"log"
+ "net"
"net/http"
"code.google.com/p/go.crypto/ssh/terminal"
)
-func ExampleListen() {
+func ExampleNewServerConn() {
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ServerConfig{
- PasswordCallback: func(conn *ServerConn, user, pass string) bool {
- return user == "testuser" && pass == "tiger"
+ PasswordCallback: func(c ConnMetadata, pass []byte) (*Permissions, error) {
+ // Should use constant-time compare (or better, salt+hash) in
+ // a production setting.
+ if c.User() == "testuser" && string(pass) == "tiger" {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("password rejected for %q", c.User())
},
}
@@ -37,50 +43,65 @@
// Once a ServerConfig has been configured, connections can be
// accepted.
- listener, err := Listen("tcp", "0.0.0.0:2022", config)
+ listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
panic("failed to listen for connection")
}
- sConn, err := listener.Accept()
+ nConn, err := listener.Accept()
if err != nil {
panic("failed to accept incoming connection")
}
- if err := sConn.Handshake(); err != nil {
+
+ // Before use, a handshake must be performed on the incoming
+ // net.Conn.
+ _, chans, reqs, err := NewServerConn(nConn, config)
+ if err != nil {
panic("failed to handshake")
}
+ // The incoming Request channel must be serviced.
+ go DiscardRequests(reqs)
- // A ServerConn multiplexes several channels, which must
- // themselves be Accepted.
- for {
- // Accept reads from the connection, demultiplexes packets
- // to their corresponding channels and returns when a new
- // channel request is seen. Some goroutine must always be
- // calling Accept; otherwise no messages will be forwarded
- // to the channels.
- channel, err := sConn.Accept()
- if err != nil {
- panic("error from Accept")
- }
-
+ // Service the incoming Channel channel.
+ for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of a shell, the type is
// "session" and ServerShell may be used to present a simple
// terminal interface.
- if channel.ChannelType() != "session" {
- channel.Reject(UnknownChannelType, "unknown channel type")
+ if newChannel.ChannelType() != "session" {
+ newChannel.Reject(UnknownChannelType, "unknown channel type")
continue
}
- channel.Accept()
+ channel, requests, err := newChannel.Accept()
+ if err != nil {
+ panic("could not accept channel.")
+ }
+
+ // Sessions have out-of-band requests such as "shell",
+ // "pty-req" and "env". Here we handle only the
+ // "shell" request.
+ go func(in <-chan *Request) {
+ for req := range in {
+ ok := false
+ switch req.Type {
+ case "shell":
+ ok = true
+ if len(req.Payload) > 0 {
+ // We don't accept any
+ // commands, only the
+ // default shell.
+ ok = false
+ }
+ }
+ req.Reply(ok, nil)
+ }
+ }(requests)
term := terminal.NewTerminal(channel, "> ")
- serverTerm := &ServerTerminal{
- Term: term,
- Channel: channel,
- }
+
go func() {
defer channel.Close()
for {
- line, err := serverTerm.ReadLine()
+ line, err := term.ReadLine()
if err != nil {
break
}
@@ -95,13 +116,11 @@
// the "password" authentication method is supported.
//
// To authenticate with the remote server you must pass at least one
- // implementation of ClientAuth via the Auth field in ClientConfig.
+ // implementation of AuthMethod via the Auth field in ClientConfig.
config := &ClientConfig{
User: "username",
- Auth: []ClientAuth{
- // ClientAuthPassword wraps a ClientPassword implementation
- // in a type that implements ClientAuth.
- ClientAuthPassword(password("yourpassword")),
+ Auth: []AuthMethod{
+ Password("yourpassword"),
},
}
client, err := Dial("tcp", "yourserver.com:22", config)
@@ -127,11 +146,11 @@
fmt.Println(b.String())
}
-func ExampleClientConn_Listen() {
+func ExampleClient_Listen() {
config := &ClientConfig{
User: "username",
- Auth: []ClientAuth{
- ClientAuthPassword(password("password")),
+ Auth: []AuthMethod{
+ Password("password"),
},
}
// Dial your ssh server.
@@ -158,8 +177,8 @@
// Create client config
config := &ClientConfig{
User: "username",
- Auth: []ClientAuth{
- ClientAuthPassword(password("password")),
+ Auth: []AuthMethod{
+ Password("password"),
},
}
// Connect to ssh server
diff --git a/ssh/handshake.go b/ssh/handshake.go
new file mode 100644
index 0000000..a1e2c23
--- /dev/null
+++ b/ssh/handshake.go
@@ -0,0 +1,393 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "crypto/rand"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "sync"
+)
+
+// debugHandshake, if set, prints (via package log) the messages sent and
+// received. Key exchange messages are printed as if DH were used, so the
+// debug messages are wrong when using ECDH.
+const debugHandshake = false
+
+// keyingTransport is a packet-based transport that supports key
+// changes. It need not be thread-safe. It should pass through
+// msgNewKeys in both directions.
+type keyingTransport interface {
+ packetConn
+
+ // prepareKeyChange sets up a key change. The key change for a
+ // direction will be effected if a msgNewKeys message is sent
+ // or received.
+ prepareKeyChange(*algorithms, *kexResult) error
+
+ // getSessionID returns the session ID. prepareKeyChange must
+ // have been called once.
+ getSessionID() []byte
+}
+
+// rekeyingTransport is the interface of handshakeTransport that we
+// (internally) expose to ClientConn and ServerConn.
+type rekeyingTransport interface {
+ packetConn
+
+ // requestKeyChange asks the remote side to change keys. All
+ // writes are blocked until the key change succeeds, which is
+ // signaled by reading a msgNewKeys. In handshakeTransport this
+ // sends our kexInit unless one is already outstanding.
+ requestKeyChange() error
+
+ // getSessionID returns the session ID. This is only valid
+ // after the first key change has completed.
+ getSessionID() []byte
+}
+
+// handshakeTransport implements rekeying on top of a keyingTransport
+// and offers a thread-safe writePacket() interface.
+type handshakeTransport struct {
+ conn keyingTransport
+ config *Config
+
+ serverVersion []byte
+ clientVersion []byte
+
+ hostKeys []Signer // If hostKeys are given, we are the server.
+
+ // On read error, incoming is closed, and readError is set.
+ incoming chan []byte
+ readError error
+
+ // data for host key checking
+ hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+ dialAddress string
+ remoteAddr net.Addr
+
+ // Bytes read since the last key exchange. Only touched by
+ // readOnePacket, i.e. from the readLoop goroutine, so it is unlocked.
+ readSinceKex uint64
+
+ // Protects the writing side of the connection
+ mu sync.Mutex
+ // cond (on mu) is broadcast when a key exchange finishes.
+ cond *sync.Cond
+ // sentInitMsg/sentInitPacket are non-nil while our kexInit is
+ // outstanding; writers wait on cond until they are cleared.
+ sentInitPacket []byte
+ sentInitMsg *kexInitMsg
+ writtenSinceKex uint64
+ // writeError, once set, is returned by all later writePacket calls.
+ writeError error
+}
+
+// newHandshakeTransport builds a handshakeTransport around conn with a
+// 16-packet incoming queue. It does not start the read loop; the role-specific
+// constructors do that once the remaining fields are set.
+func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
+	t := new(handshakeTransport)
+	t.conn = conn
+	t.config = config
+	t.clientVersion = clientVersion
+	t.serverVersion = serverVersion
+	t.incoming = make(chan []byte, 16)
+	t.cond = sync.NewCond(&t.mu)
+	return t
+}
+
+// newClientTransport returns a client-side handshakeTransport and starts its
+// read loop. dialAddr and addr are passed to config.HostKeyCallback (if set)
+// when the server's host key is checked.
+func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
+	t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
+	t.hostKeyCallback, t.dialAddress, t.remoteAddr = config.HostKeyCallback, dialAddr, addr
+	go t.readLoop()
+	return t
+}
+
+// newServerTransport returns a server-side handshakeTransport (it is given
+// config's host keys) and starts its read loop.
+func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
+	srv := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
+	srv.hostKeys = config.hostKeys
+	go srv.readLoop()
+	return srv
+}
+
+// getSessionID returns the session ID held by the underlying transport.
+func (t *handshakeTransport) getSessionID() []byte {
+	return t.conn.getSessionID()
+}
+
+// id names this side for debug logging: "server" when host keys are
+// configured, "client" otherwise.
+func (t *handshakeTransport) id() string {
+	role := "client"
+	if len(t.hostKeys) > 0 {
+		role = "server"
+	}
+	return role
+}
+
+// readPacket returns the next packet queued by readLoop. Once the loop has
+// stopped and the queue is drained, it returns the error that stopped it.
+func (t *handshakeTransport) readPacket() ([]byte, error) {
+	if p, ok := <-t.incoming; ok {
+		return p, nil
+	}
+	return nil, t.readError
+}
+
+// readLoop moves packets from the underlying transport to t.incoming,
+// dropping msgIgnore and msgDebug, until reading fails. The failure is then
+// reported to readers (t.readError, with t.incoming closed) and to writers
+// (t.writeError).
+func (t *handshakeTransport) readLoop() {
+	for {
+		p, err := t.readOnePacket()
+		if err != nil {
+			t.readError = err
+			close(t.incoming)
+			break
+		}
+		if p[0] == msgIgnore || p[0] == msgDebug {
+			continue
+		}
+		t.incoming <- p
+	}
+
+	// If we can't read, declare the writing side dead too. Otherwise a
+	// writer waiting for an outstanding key exchange to finish would wait
+	// forever, since the peer's kexInit can no longer be read.
+	t.mu.Lock()
+	if t.writeError == nil {
+		t.writeError = t.readError
+	}
+	t.cond.Broadcast()
+	t.mu.Unlock()
+}
+
+// readOnePacket reads the next packet from the underlying transport. It asks
+// for a key change once more than RekeyThreshold bytes have been read since
+// the last one, and when the packet is a kexInit it runs the whole key
+// exchange, returning a synthetic msgNewKeys packet on success. It is only
+// called from readLoop, so readSinceKex needs no locking.
+func (t *handshakeTransport) readOnePacket() ([]byte, error) {
+	if t.readSinceKex > t.config.RekeyThreshold {
+		if err := t.requestKeyChange(); err != nil {
+			return nil, err
+		}
+	}
+
+	p, err := t.conn.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	t.readSinceKex += uint64(len(p))
+	if debugHandshake {
+		msg, err := decode(p)
+		log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
+	}
+	if p[0] != msgKexInit {
+		return p, nil
+	}
+	err = t.enterKeyExchange(p)
+
+	t.mu.Lock()
+	if err != nil {
+		// drop connection
+		t.conn.Close()
+		t.writeError = err
+	}
+
+	if debugHandshake {
+		log.Printf("%s exited key exchange, err %v", t.id(), err)
+	}
+
+	// Unblock writers.
+	t.sentInitMsg = nil
+	t.sentInitPacket = nil
+	t.cond.Broadcast()
+	t.writtenSinceKex = 0
+	t.mu.Unlock()
+
+	if err != nil {
+		return nil, err
+	}
+
+	t.readSinceKex = 0
+	return []byte{msgNewKeys}, nil
+}
+
+// sendKexInit sends a key change message, and returns the message
+// that was sent. After initiating the key change, all writes will be
+// blocked until the change is done, and a failed key change will
+// close the underlying transport. This function is safe for
+// concurrent use by multiple goroutines.
+func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.sendKexInitLocked()
+}
+
+// requestKeyChange starts a key exchange by sending our kexInit (a no-op if
+// one is already outstanding). The exchange completes in readLoop once the
+// peer's kexInit arrives.
+func (t *handshakeTransport) requestKeyChange() error {
+	_, _, err := t.sendKexInit()
+	return err
+}
+
+// sendKexInitLocked sends a kexInit message, unless one is already
+// outstanding, and returns the message together with its wire encoding.
+// t.mu must be held by the caller.
+func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
+	// kexInits may be sent either in response to the other side,
+	// or because our side wants to initiate a key change, so we
+	// may have already sent a kexInit. In that case, don't send a
+	// second kexInit.
+	if t.sentInitMsg != nil {
+		return t.sentInitMsg, t.sentInitPacket, nil
+	}
+	msg := &kexInitMsg{
+		KexAlgos:                t.config.KeyExchanges,
+		CiphersClientServer:     t.config.Ciphers,
+		CiphersServerClient:     t.config.Ciphers,
+		MACsClientServer:        t.config.MACs,
+		MACsServerClient:        t.config.MACs,
+		CompressionClientServer: supportedCompressions,
+		CompressionServerClient: supportedCompressions,
+	}
+	// The cookie is the random value required by RFC 4253, section 7.1.
+	// Ignoring a read failure would send a zero or partial cookie.
+	if _, err := io.ReadFull(rand.Reader, msg.Cookie[:]); err != nil {
+		return nil, nil, err
+	}
+
+	if len(t.hostKeys) > 0 {
+		// Server: offer exactly the algorithms of our own host keys.
+		for _, k := range t.hostKeys {
+			msg.ServerHostKeyAlgos = append(
+				msg.ServerHostKeyAlgos, k.PublicKey().Type())
+		}
+	} else {
+		// Client: accept any supported host key algorithm.
+		msg.ServerHostKeyAlgos = supportedHostKeyAlgos
+	}
+	packet := Marshal(msg)
+
+	// writePacket destroys the contents, so save a copy.
+	packetCopy := make([]byte, len(packet))
+	copy(packetCopy, packet)
+
+	if err := t.conn.writePacket(packetCopy); err != nil {
+		return nil, nil, err
+	}
+
+	t.sentInitMsg = msg
+	t.sentInitPacket = packet
+	return msg, packet, nil
+}
+
+// writePacket sends p over the transport. It is safe for concurrent use:
+// writes are serialized by t.mu and held back while a key exchange is in
+// progress. If more than RekeyThreshold bytes have been written since the
+// last key exchange, a new one is requested first. msgKexInit and
+// msgNewKeys may only be sent by handshakeTransport itself.
+func (t *handshakeTransport) writePacket(p []byte) error {
+	if len(p) == 0 {
+		return errors.New("ssh: cannot write empty packet")
+	}
+	switch p[0] {
+	case msgKexInit:
+		return errors.New("ssh: only handshakeTransport can send kexInit")
+	case msgNewKeys:
+		return errors.New("ssh: only handshakeTransport can send newKeys")
+	}
+
+	t.mu.Lock()
+	// Release the lock on every path, including the writeError return;
+	// returning with t.mu held would deadlock all later writers and the
+	// read loop.
+	defer t.mu.Unlock()
+
+	if t.writtenSinceKex > t.config.RekeyThreshold {
+		if _, _, err := t.sendKexInitLocked(); err != nil {
+			return err
+		}
+	}
+	// Wait for an outstanding key exchange, unless the connection has
+	// already failed (in which case it will never complete).
+	for t.sentInitMsg != nil && t.writeError == nil {
+		t.cond.Wait()
+	}
+	if t.writeError != nil {
+		return t.writeError
+	}
+	t.writtenSinceKex += uint64(len(p))
+	return t.conn.writePacket(p)
+}
+
+// Close closes the underlying keying transport.
+func (t *handshakeTransport) Close() error {
+	return t.conn.Close()
+}
+
+// enterKeyExchange runs a key exchange in response to the peer's kexInit
+// packet otherInitPacket. It sends our own kexInit (unless one is already
+// outstanding), negotiates algorithms, runs the kex as client or server, and
+// switches to the new keys by exchanging msgNewKeys.
+func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
+	if debugHandshake {
+		log.Printf("%s entered key exchange", t.id())
+	}
+	myInit, myInitPacket, err := t.sendKexInit()
+	if err != nil {
+		return err
+	}
+
+	otherInit := &kexInitMsg{}
+	if err := Unmarshal(otherInitPacket, otherInit); err != nil {
+		return err
+	}
+
+	magics := handshakeMagics{
+		clientVersion: t.clientVersion,
+		serverVersion: t.serverVersion,
+		clientKexInit: otherInitPacket,
+		serverKexInit: myInitPacket,
+	}
+
+	clientInit := otherInit
+	serverInit := myInit
+	if len(t.hostKeys) == 0 {
+		clientInit = myInit
+		serverInit = otherInit
+
+		magics.clientKexInit = myInitPacket
+		magics.serverKexInit = otherInitPacket
+	}
+
+	algs := findAgreedAlgorithms(clientInit, serverInit)
+	if algs == nil {
+		return errors.New("ssh: no common algorithms")
+	}
+
+	// We don't send FirstKexFollows, but we handle receiving it.
+	//
+	// RFC 4253, section 7: the guessed packet must be ignored if the kex
+	// algorithm and/or the host key algorithm was guessed wrong, i.e. if
+	// the two sides' first (preferred) choices differ. Comparing only the
+	// agreed kex against the peer's first choice misses cases such as a
+	// client preferring [A B] and a server preferring [B A]. The lists are
+	// presumably non-empty here because negotiation succeeded above.
+	if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] ||
+		clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
+		// other side sent a kex message for the wrong algorithm,
+		// which we have to ignore.
+		if _, err := t.conn.readPacket(); err != nil {
+			return err
+		}
+	}
+
+	kex, ok := kexAlgoMap[algs.kex]
+	if !ok {
+		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
+	}
+
+	var result *kexResult
+	if len(t.hostKeys) > 0 {
+		result, err = t.server(kex, algs, &magics)
+	} else {
+		result, err = t.client(kex, algs, &magics)
+	}
+
+	if err != nil {
+		return err
+	}
+
+	// Abort if the new keys cannot be installed; announcing msgNewKeys for
+	// keys we cannot use would desynchronize the stream.
+	if err := t.conn.prepareKeyChange(algs, result); err != nil {
+		return err
+	}
+	if err := t.conn.writePacket([]byte{msgNewKeys}); err != nil {
+		return err
+	}
+	packet, err := t.conn.readPacket()
+	if err != nil {
+		return err
+	}
+	if packet[0] != msgNewKeys {
+		return unexpectedMessageError(msgNewKeys, packet[0])
+	}
+	return nil
+}
+
+// server runs the server side of kex, signing with the host key whose type
+// equals the negotiated host key algorithm. If several keys share that type,
+// the last one wins; if none does, hostKey is passed on as nil.
+func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
+	var hostKey Signer
+	for _, candidate := range t.hostKeys {
+		if candidate.PublicKey().Type() != algs.hostKey {
+			continue
+		}
+		hostKey = candidate
+	}
+	return kex.Server(t.conn, t.config.Rand, magics, hostKey)
+}
+
+// client runs the client side of kex and then vets the server's host key:
+// verifyHostKeySignature must accept it, and so must hostKeyCallback when
+// one is configured.
+func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
+	result, err := kex.Client(t.conn, t.config.Rand, magics)
+	if err != nil {
+		return nil, err
+	}
+
+	hostKey, err := ParsePublicKey(result.HostKey)
+	if err != nil {
+		return nil, err
+	}
+	if err = verifyHostKeySignature(hostKey, result); err != nil {
+		return nil, err
+	}
+
+	// Without a callback, any host key that passed verification is accepted.
+	if t.hostKeyCallback == nil {
+		return result, nil
+	}
+	if err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey); err != nil {
+		return nil, err
+	}
+	return result, nil
+}
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
new file mode 100644
index 0000000..613c498
--- /dev/null
+++ b/ssh/handshake_test.go
@@ -0,0 +1,311 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "testing"
+)
+
+// testChecker is a host key callback that records each key it accepts.
+type testChecker struct {
+	calls []string // one "dialAddr addr type key" entry per accepted key
+}
+
+// Check rejects dialAddr "bad" and non-TCP addresses; otherwise it records the call.
+func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+	switch tcpAddr, isTCP := addr.(*net.TCPAddr); {
+	case dialAddr == "bad":
+		return fmt.Errorf("dialAddr is bad")
+	case !isTCP || tcpAddr == nil:
+		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
+	}
+	entry := fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())
+	t.calls = append(t.calls, entry)
+	return nil
+}
+
+// netPipe returns the two ends of a loopback TCP connection. It is like
+// net.Pipe, but the ends are real, buffered sockets, so both sides may
+// start with a write without deadlocking (net.Pipe would).
+func netPipe() (net.Conn, net.Conn, error) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return nil, nil, err
+	}
+	// Only one connection is needed, so stop listening on return.
+	defer ln.Close()
+
+	dialed, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		return nil, nil, err
+	}
+	accepted, err := ln.Accept()
+	if err != nil {
+		dialed.Close() // don't leak the dialing end
+		return nil, nil, err
+	}
+	return dialed, accepted, nil
+}
+
+// handshakePair returns a client and a server handshakeTransport joined by
+// netPipe. The client uses clientConf (defaults filled in) and is given addr
+// as its dial address; the server presents testSigners["ecdsa"].
+func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
+	clientConn, serverConn, err := netPipe()
+	if err != nil {
+		return nil, nil, err
+	}
+	clientTr := newTransport(clientConn, rand.Reader, true)
+	serverTr := newTransport(serverConn, rand.Reader, false)
+	clientConf.SetDefaults()
+
+	version := []byte("version")
+	client = newClientTransport(clientTr, version, version, clientConf, addr, clientConn.RemoteAddr())
+
+	serverConf := new(ServerConfig)
+	serverConf.AddHostKey(testSigners["ecdsa"])
+	serverConf.SetDefaults()
+	return client, newServerTransport(serverTr, version, version, serverConf), nil
+}
+
+// TestHandshakeBasic checks that a key change requested halfway through a
+// stream of client writes neither drops nor reorders application packets.
+func TestHandshakeBasic(t *testing.T) {
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	go func() {
+		// Client writes a bunch of stuff and rekeys in the middle. Off the
+		// test goroutine t.Fatalf is not allowed, so report with t.Errorf and
+		// return; the deferred Close then also ends the server's read loop.
+		defer trC.Close()
+		for i := 0; i < 10; i++ {
+			p := []byte{msgRequestSuccess, byte(i)}
+			if err := trC.writePacket(p); err != nil {
+				t.Errorf("writePacket: %v", err)
+				return
+			}
+			if i == 5 { // halfway through, request a key change
+				if _, _, err := trC.sendKexInit(); err != nil {
+					t.Errorf("sendKexInit: %v", err)
+					return
+				}
+			}
+		}
+	}()
+
+	// Server checks that client messages come in cleanly, skipping NewKeys.
+	i := 0
+	for {
+		p, err := trS.readPacket()
+		if err != nil {
+			break
+		}
+		if p[0] == msgNewKeys {
+			continue
+		}
+		want := []byte{msgRequestSuccess, byte(i)}
+		if !bytes.Equal(p, want) {
+			t.Errorf("message %d: got %q, want %q", i, p, want)
+		}
+		i++
+	}
+	if i != 10 {
+		t.Errorf("received %d messages, want 10.", i)
+	}
+
+	// If all went well, we registered exactly 1 key change.
+	if len(checker.calls) != 1 {
+		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
+	}
+	pub := testSigners["ecdsa"].PublicKey()
+	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
+	if want != checker.calls[0] {
+		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
+	}
+}
+
+// TestHandshakeError checks that once the host key callback rejects the key
+// exchange (dial address "bad"), the transport carries no further traffic.
+func TestHandshakeError(t *testing.T) {
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	// Send one packet, then request a key change.
+	first := []byte{msgRequestSuccess, 42}
+	if werr := trC.writePacket(first); werr != nil {
+		t.Errorf("writePacket: %v", werr)
+	}
+	if _, _, err = trC.sendKexInit(); err != nil {
+		t.Errorf("sendKexInit: %v", err)
+	}
+
+	// The key change fails, so any later write must be refused.
+	if werr := trC.writePacket([]byte{msgRequestSuccess, 43}); werr == nil {
+		t.Errorf("writePacket after botched rekey succeeded.")
+	}
+
+	// The server still receives the packet sent before the key change...
+	got, err := trS.readPacket()
+	if err != nil {
+		t.Fatalf("server closed too soon: %v", err)
+	}
+	if !bytes.Equal(got, first) {
+		t.Errorf("got %q want %q", got, first)
+	}
+	// ...but nothing after the failed key change.
+	if got, err = trS.readPacket(); err == nil {
+		t.Errorf("got a message %q after failed key change", got)
+	}
+}
+
+// TestHandshakeTwice checks that two back-to-back client-initiated key
+// changes each consult the host key callback and lose no packets.
+func TestHandshakeTwice(t *testing.T) {
+	checker := &testChecker{}
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	// newPacket returns a fresh 5-byte msgRequestSuccess packet; every write
+	// needs its own buffer, since writePacket destroys the one it is given.
+	newPacket := func() []byte {
+		p := make([]byte, 5)
+		p[0] = msgRequestSuccess
+		return p
+	}
+
+	// send a packet
+	if err := trC.writePacket(newPacket()); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+	// Now request a key change.
+	if _, _, err := trC.sendKexInit(); err != nil {
+		t.Errorf("sendKexInit: %v", err)
+	}
+	// Send another packet.
+	if err := trC.writePacket(newPacket()); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+	// 2nd key change.
+	if _, _, err := trC.sendKexInit(); err != nil {
+		t.Errorf("sendKexInit: %v", err)
+	}
+	// And a final packet.
+	if err := trC.writePacket(newPacket()); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+
+	// Five reads are expected to cover the three packets plus the NewKeys
+	// messages of the key changes; NewKeys messages are skipped.
+	want := newPacket()
+	for i := 0; i < 5; i++ {
+		msg, err := trS.readPacket()
+		if err != nil {
+			t.Fatalf("server closed too soon: %v", err)
+		}
+		if msg[0] == msgNewKeys {
+			continue
+		}
+		if !bytes.Equal(msg, want) {
+			t.Errorf("packet %d: got %q want %q", i, msg, want)
+		}
+	}
+
+	// Two key changes, so two host key checks.
+	if len(checker.calls) != 2 {
+		t.Errorf("got %d key changes, want 2", len(checker.calls))
+	}
+}
+
+// TestHandshakeAutoRekeyWrite checks that client writes past RekeyThreshold trigger a rekey.
+func TestHandshakeAutoRekeyWrite(t *testing.T) {
+	checker := &testChecker{}
+	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+	clientConf.RekeyThreshold = 500
+	trC, trS, err := handshakePair(clientConf, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	// Together, five 251-byte packets exceed the RekeyThreshold of 500.
+	for n := 0; n < 5; n++ {
+		p := make([]byte, 251)
+		p[0] = msgRequestSuccess
+		if err := trC.writePacket(p); err != nil {
+			t.Errorf("writePacket: %v", err)
+		}
+	}
+
+	received := 0
+	for received < 5 {
+		if _, err := trS.readPacket(); err != nil {
+			break
+		}
+		received++
+	}
+	if received != 5 {
+		t.Errorf("got %d, want 5 messages", received)
+	}
+	if len(checker.calls) != 2 {
+		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
+	}
+}
+
+// syncChecker is a host key callback that accepts every key and
+// announces each call by sending on called.
+type syncChecker struct{ called chan int }
+
+func (s *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+	s.called <- 1 // blocks while the channel's buffer is full
+	return nil
+}
+
+// TestHandshakeAutoRekeyRead checks that a client reading a packet larger
+// than RekeyThreshold starts a key change on its own.
+func TestHandshakeAutoRekeyRead(t *testing.T) {
+	checker := &syncChecker{called: make(chan int, 2)}
+	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+	clientConf.RekeyThreshold = 500
+	trC, trS, err := handshakePair(clientConf, "addr")
+	if err != nil {
+		t.Fatalf("handshakePair: %v", err)
+	}
+	defer trC.Close()
+	defer trS.Close()
+
+	// A single 501-byte packet from the server crosses the threshold.
+	large := make([]byte, 501)
+	large[0] = msgRequestSuccess
+	if err := trS.writePacket(large); err != nil {
+		t.Fatalf("writePacket: %v", err)
+	}
+	// While we read out the packet, a key change will be initiated.
+	if _, err := trC.readPacket(); err != nil {
+		t.Fatalf("readPacket(client): %v", err)
+	}
+
+	// Block until the host key callback has been called.
+	<-checker.called
+}
diff --git a/ssh/kex.go b/ssh/kex.go
index d2e3b70..6a835c7 100644
--- a/ssh/kex.go
+++ b/ssh/kex.go
@@ -30,10 +30,10 @@
// Shared secret. See also RFC 4253, section 8.
K []byte
- // Host key as hashed into H
+ // Host key as hashed into H.
HostKey []byte
- // Signature of H
+ // Signature of H.
Signature []byte
// A cryptographic hash function that matches the security
@@ -94,7 +94,7 @@
kexDHInit := kexDHInitMsg{
X: X,
}
- if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
+ if err := c.writePacket(Marshal(&kexDHInit)); err != nil {
return nil, err
}
@@ -104,7 +104,7 @@
}
var kexDHReply kexDHReplyMsg
- if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
+ if err = Unmarshal(packet, &kexDHReply); err != nil {
return nil, err
}
@@ -138,7 +138,7 @@
return
}
var kexDHInit kexDHInitMsg
- if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
+ if err = Unmarshal(packet, &kexDHInit); err != nil {
return
}
@@ -153,7 +153,7 @@
return nil, err
}
- hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+ hostKeyBytes := priv.PublicKey().Marshal()
h := hashFunc.New()
magics.write(h)
@@ -179,7 +179,7 @@
Y: Y,
Signature: sig,
}
- packet = marshal(msgKexDHReply, kexDHReply)
+ packet = Marshal(&kexDHReply)
err = c.writePacket(packet)
return &kexResult{
@@ -207,7 +207,7 @@
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
}
- serialized := marshal(msgKexECDHInit, kexInit)
+ serialized := Marshal(&kexInit)
if err := c.writePacket(serialized); err != nil {
return nil, err
}
@@ -218,7 +218,7 @@
}
var reply kexECDHReplyMsg
- if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil {
+ if err = Unmarshal(packet, &reply); err != nil {
return nil, err
}
@@ -297,7 +297,7 @@
}
var kexECDHInit kexECDHInitMsg
- if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil {
+ if err = Unmarshal(packet, &kexECDHInit); err != nil {
return nil, err
}
@@ -314,7 +314,7 @@
return nil, err
}
- hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+ hostKeyBytes := priv.PublicKey().Marshal()
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
@@ -346,7 +346,7 @@
Signature: sig,
}
- serialized := marshal(msgKexECDHReply, reply)
+ serialized := Marshal(&reply)
if err := c.writePacket(serialized); err != nil {
return nil, err
}
diff --git a/ssh/kex_test.go b/ssh/kex_test.go
index 1e931a3..0db5f9b 100644
--- a/ssh/kex_test.go
+++ b/ssh/kex_test.go
@@ -29,7 +29,7 @@
c <- kexResultErr{r, e}
}()
go func() {
- r, e := kex.Server(b, rand.Reader, &magics, ecdsaKey)
+ r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
s <- kexResultErr{r, e}
}()
diff --git a/ssh/keys.go b/ssh/keys.go
index b41fefc..e8af511 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -33,7 +33,7 @@
// parsePubKey parses a public key of the given algorithm.
// Use ParsePublicKey for keys with prepended algorithm.
-func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, ok bool) {
+func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) {
switch algo {
case KeyAlgoRSA:
return parseRSA(in)
@@ -42,15 +42,19 @@
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
return parseECDSA(in)
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
- return parseOpenSSHCertV01(in, algo)
+ cert, err := parseCert(in, certToPrivAlgo(algo))
+ if err != nil {
+ return nil, nil, err
+ }
+ return cert, nil, nil
}
- return nil, nil, false
+ return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo)
}
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
// (see sshd(8) manual page) once the options and key type fields have been
// removed.
-func parseAuthorizedKey(in []byte) (out PublicKey, comment string, ok bool) {
+func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) {
in = bytes.TrimSpace(in)
i := bytes.IndexAny(in, " \t")
@@ -62,20 +66,20 @@
key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key)))
n, err := base64.StdEncoding.Decode(key, base64Key)
if err != nil {
- return
+ return nil, "", err
}
key = key[:n]
- out, _, ok = ParsePublicKey(key)
- if !ok {
- return nil, "", false
+ out, err = ParsePublicKey(key)
+ if err != nil {
+ return nil, "", err
}
comment = string(bytes.TrimSpace(in[i:]))
- return
+ return out, comment, nil
}
// ParseAuthorizedKeys parses a public key from an authorized_keys
// file used in OpenSSH according to the sshd(8) manual page.
-func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, ok bool) {
+func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
for len(in) > 0 {
end := bytes.IndexByte(in, '\n')
if end != -1 {
@@ -102,8 +106,8 @@
continue
}
- if out, comment, ok = parseAuthorizedKey(in[i:]); ok {
- return
+ if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
+ return out, comment, options, rest, nil
}
// No key type recognised. Maybe there's an options field at
@@ -143,38 +147,42 @@
continue
}
- if out, comment, ok = parseAuthorizedKey(in[i:]); ok {
+ if out, comment, err = parseAuthorizedKey(in[i:]); err == nil {
options = candidateOptions
- return
+ return out, comment, options, rest, nil
}
in = rest
continue
}
- return
+ return nil, "", nil, nil, errors.New("ssh: no key found")
}
// ParsePublicKey parses an SSH public key formatted for use in
// the SSH wire protocol according to RFC 4253, section 6.6.
-func ParsePublicKey(in []byte) (out PublicKey, rest []byte, ok bool) {
+func ParsePublicKey(in []byte) (out PublicKey, err error) {
algo, in, ok := parseString(in)
if !ok {
- return
+ return nil, errShortRead
+ }
+ var rest []byte
+ out, rest, err = parsePubKey(in, string(algo))
+ if len(rest) > 0 {
+ return nil, errors.New("ssh: trailing junk in public key")
}
- return parsePubKey(in, string(algo))
+ return out, err
}
-// MarshalAuthorizedKey returns a byte stream suitable for inclusion
-// in an OpenSSH authorized_keys file following the format specified
-// in the sshd(8) manual page.
+// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
+// authorized_keys file. The return value ends with newline.
func MarshalAuthorizedKey(key PublicKey) []byte {
b := &bytes.Buffer{}
- b.WriteString(key.PublicKeyAlgo())
+ b.WriteString(key.Type())
b.WriteByte(' ')
e := base64.NewEncoder(base64.StdEncoding, b)
- e.Write(MarshalPublicKey(key))
+ e.Write(key.Marshal())
e.Close()
b.WriteByte('\n')
return b.Bytes()
@@ -182,84 +190,81 @@
// PublicKey is an abstraction of different types of public keys.
type PublicKey interface {
- // PrivateKeyAlgo returns the name of the encryption system.
- PrivateKeyAlgo() string
-
- // PublicKeyAlgo returns the algorithm for the public key,
- // which may be different from PrivateKeyAlgo for certificates.
- PublicKeyAlgo() string
+ // Type returns the key's type, e.g. "ssh-rsa".
+ Type() string
// Marshal returns the serialized key data in SSH wire format,
- // without the name prefix. Callers should typically use
- // MarshalPublicKey().
+ // with the name prefix.
Marshal() []byte
// Verify that sig is a signature on the given data using this
// key. This function will hash the data appropriately first.
- Verify(data []byte, sigBlob []byte) bool
+ Verify(data []byte, sig *Signature) error
}
-// A Signer is can create signatures that verify against a public key.
+// A Signer can create signatures that verify against a public key.
type Signer interface {
// PublicKey returns an associated PublicKey instance.
PublicKey() PublicKey
// Sign returns raw signature for the given data. This method
// will apply the hash specified for the keytype to the data.
- Sign(rand io.Reader, data []byte) ([]byte, error)
+ Sign(rand io.Reader, data []byte) (*Signature, error)
}
type rsaPublicKey rsa.PublicKey
-func (r *rsaPublicKey) PrivateKeyAlgo() string {
+func (r *rsaPublicKey) Type() string {
return "ssh-rsa"
}
-func (r *rsaPublicKey) PublicKeyAlgo() string {
- return r.PrivateKeyAlgo()
-}
-
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
-func parseRSA(in []byte) (out PublicKey, rest []byte, ok bool) {
- key := new(rsa.PublicKey)
-
- bigE, in, ok := parseInt(in)
- if !ok || bigE.BitLen() > 24 {
- return
+func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
+ var w struct {
+ E *big.Int
+ N *big.Int
+ Rest []byte `ssh:"rest"`
}
- e := bigE.Int64()
+ if err := Unmarshal(in, &w); err != nil {
+ return nil, nil, err
+ }
+
+ if w.E.BitLen() > 24 {
+ return nil, nil, errors.New("ssh: exponent too large")
+ }
+ e := w.E.Int64()
if e < 3 || e&1 == 0 {
- ok = false
- return
+ return nil, nil, errors.New("ssh: incorrect exponent")
}
+
+ var key rsa.PublicKey
key.E = int(e)
-
- if key.N, in, ok = parseInt(in); !ok {
- return
- }
-
- ok = true
- return (*rsaPublicKey)(key), in, ok
+ key.N = w.N
+ return (*rsaPublicKey)(&key), w.Rest, nil
}
func (r *rsaPublicKey) Marshal() []byte {
- // See RFC 4253, section 6.6.
e := new(big.Int).SetInt64(int64(r.E))
- length := intLength(e)
- length += intLength(r.N)
-
- ret := make([]byte, length)
- rest := marshalInt(ret, e)
- marshalInt(rest, r.N)
-
- return ret
+ wirekey := struct {
+ Name string
+ E *big.Int
+ N *big.Int
+ }{
+ KeyAlgoRSA,
+ e,
+ r.N,
+ }
+ return Marshal(&wirekey)
}
-func (r *rsaPublicKey) Verify(data []byte, sig []byte) bool {
+func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
+ if sig.Format != r.Type() {
+ return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
+ }
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
- return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig) == nil
+ return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob)
}
type rsaPrivateKey struct {
@@ -270,64 +275,66 @@
return (*rsaPublicKey)(&r.PrivateKey.PublicKey)
}
-func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
- return rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
+ blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
+ if err != nil {
+ return nil, err
+ }
+ return &Signature{
+ Format: r.PublicKey().Type(),
+ Blob: blob,
+ }, nil
}
type dsaPublicKey dsa.PublicKey
-func (r *dsaPublicKey) PrivateKeyAlgo() string {
+func (r *dsaPublicKey) Type() string {
return "ssh-dss"
}
-func (r *dsaPublicKey) PublicKeyAlgo() string {
- return r.PrivateKeyAlgo()
-}
-
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
-func parseDSA(in []byte) (out PublicKey, rest []byte, ok bool) {
- key := new(dsa.PublicKey)
-
- if key.P, in, ok = parseInt(in); !ok {
- return
+func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
+ var w struct {
+ P, Q, G, Y *big.Int
+ Rest []byte `ssh:"rest"`
+ }
+ if err := Unmarshal(in, &w); err != nil {
+ return nil, nil, err
}
- if key.Q, in, ok = parseInt(in); !ok {
- return
+ key := &dsaPublicKey{
+ Parameters: dsa.Parameters{
+ P: w.P,
+ Q: w.Q,
+ G: w.G,
+ },
+ Y: w.Y,
}
-
- if key.G, in, ok = parseInt(in); !ok {
- return
- }
-
- if key.Y, in, ok = parseInt(in); !ok {
- return
- }
-
- ok = true
- return (*dsaPublicKey)(key), in, ok
+ return key, w.Rest, nil
}
-func (r *dsaPublicKey) Marshal() []byte {
- // See RFC 4253, section 6.6.
- length := intLength(r.P)
- length += intLength(r.Q)
- length += intLength(r.G)
- length += intLength(r.Y)
+func (k *dsaPublicKey) Marshal() []byte {
+ w := struct {
+ Name string
+ P, Q, G, Y *big.Int
+ }{
+ k.Type(),
+ k.P,
+ k.Q,
+ k.G,
+ k.Y,
+ }
- ret := make([]byte, length)
- rest := marshalInt(ret, r.P)
- rest = marshalInt(rest, r.Q)
- rest = marshalInt(rest, r.G)
- marshalInt(rest, r.Y)
-
- return ret
+ return Marshal(&w)
}
-func (k *dsaPublicKey) Verify(data []byte, sigBlob []byte) bool {
+func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
+ if sig.Format != k.Type() {
+ return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
+ }
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
@@ -337,12 +344,15 @@
// r, followed by s (which are 160-bit integers, without lengths or
// padding, unsigned, and in network byte order).
// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
- if len(sigBlob) != 40 {
- return false
+ if len(sig.Blob) != 40 {
+ return errors.New("ssh: DSA signature parse error")
}
- r := new(big.Int).SetBytes(sigBlob[:20])
- s := new(big.Int).SetBytes(sigBlob[20:])
- return dsa.Verify((*dsa.PublicKey)(k), digest, r, s)
+ r := new(big.Int).SetBytes(sig.Blob[:20])
+ s := new(big.Int).SetBytes(sig.Blob[20:])
+ if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) {
+ return nil
+ }
+ return errors.New("ssh: signature did not verify")
}
type dsaPrivateKey struct {
@@ -353,7 +363,7 @@
return (*dsaPublicKey)(&k.PrivateKey.PublicKey)
}
-func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
@@ -363,14 +373,21 @@
}
sig := make([]byte, 40)
- copy(sig[:20], r.Bytes())
- copy(sig[20:], s.Bytes())
- return sig, nil
+ rb := r.Bytes()
+ sb := s.Bytes()
+
+ copy(sig[20-len(rb):20], rb)
+ copy(sig[40-len(sb):], sb)
+
+ return &Signature{
+ Format: k.PublicKey().Type(),
+ Blob: sig,
+ }, nil
}
type ecdsaPublicKey ecdsa.PublicKey
-func (key *ecdsaPublicKey) PrivateKeyAlgo() string {
+func (key *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID()
}
@@ -387,7 +404,7 @@
}
func supportedEllipticCurve(curve elliptic.Curve) bool {
- return (curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521())
+ return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521()
}
// ecHash returns the hash to match the given elliptic curve, see RFC
@@ -403,15 +420,11 @@
return crypto.SHA512
}
-func (key *ecdsaPublicKey) PublicKeyAlgo() string {
- return key.PrivateKeyAlgo()
-}
-
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
-func parseECDSA(in []byte) (out PublicKey, rest []byte, ok bool) {
- var identifier []byte
- if identifier, in, ok = parseString(in); !ok {
- return
+func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
+ identifier, in, ok := parseString(in)
+ if !ok {
+ return nil, nil, errShortRead
}
key := new(ecdsa.PublicKey)
@@ -424,38 +437,42 @@
case "nistp521":
key.Curve = elliptic.P521()
default:
- ok = false
- return
+ return nil, nil, errors.New("ssh: unsupported curve")
}
var keyBytes []byte
if keyBytes, in, ok = parseString(in); !ok {
- return
+ return nil, nil, errShortRead
}
key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes)
if key.X == nil || key.Y == nil {
- ok = false
- return
+ return nil, nil, errors.New("ssh: invalid curve point")
}
- return (*ecdsaPublicKey)(key), in, ok
+ return (*ecdsaPublicKey)(key), in, nil
}
func (key *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
+ w := struct {
+ Name string
+ ID string
+ Key []byte
+ }{
+ key.Type(),
+ key.nistID(),
+ keyBytes,
+ }
- ID := key.nistID()
- length := stringLength(len(ID))
- length += stringLength(len(keyBytes))
-
- ret := make([]byte, length)
- r := marshalString(ret, []byte(ID))
- r = marshalString(r, keyBytes)
- return ret
+ return Marshal(&w)
}
-func (key *ecdsaPublicKey) Verify(data []byte, sigBlob []byte) bool {
+func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
+ if sig.Format != key.Type() {
+ return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
+ }
+
h := ecHash(key.Curve).New()
h.Write(data)
digest := h.Sum(nil)
@@ -464,15 +481,19 @@
// The ecdsa_signature_blob value has the following specific encoding:
// mpint r
// mpint s
- r, rest, ok := parseInt(sigBlob)
- if !ok {
- return false
+ var ecSig struct {
+ R *big.Int
+ S *big.Int
}
- s, rest, ok := parseInt(rest)
- if !ok || len(rest) > 0 {
- return false
+
+ if err := Unmarshal(sig.Blob, &ecSig); err != nil {
+ return err
}
- return ecdsa.Verify((*ecdsa.PublicKey)(key), digest, r, s)
+
+ if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) {
+ return nil
+ }
+ return errors.New("ssh: signature did not verify")
}
type ecdsaPrivateKey struct {
@@ -483,7 +504,7 @@
return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey)
}
-func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := ecHash(k.PrivateKey.PublicKey.Curve).New()
h.Write(data)
digest := h.Sum(nil)
@@ -495,10 +516,13 @@
sig := make([]byte, intLength(r)+intLength(s))
rest := marshalInt(sig, r)
marshalInt(rest, s)
- return sig, nil
+ return &Signature{
+ Format: k.PublicKey().Type(),
+ Blob: sig,
+ }, nil
}
-// NewPrivateKey takes a pointer to rsa, dsa or ecdsa PrivateKey
+// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
// returns a corresponding Signer instance. EC keys should use P256,
// P384 or P521.
func NewSignerFromKey(k interface{}) (Signer, error) {
@@ -540,54 +564,49 @@
return sshKey, nil
}
-// ParsePublicKey parses a PEM encoded private key. It supports
-// PKCS#1, RSA, DSA and ECDSA private keys.
+// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
+// the same keys as ParseRawPrivateKey.
func ParsePrivateKey(pemBytes []byte) (Signer, error) {
+ key, err := ParseRawPrivateKey(pemBytes)
+ if err != nil {
+ return nil, err
+ }
+
+ return NewSignerFromKey(key)
+}
+
+// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
+// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
+func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("ssh: no key found")
}
- var rawkey interface{}
switch block.Type {
case "RSA PRIVATE KEY":
- rsa, err := x509.ParsePKCS1PrivateKey(block.Bytes)
- if err != nil {
- return nil, err
- }
- rawkey = rsa
+ return x509.ParsePKCS1PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
- ec, err := x509.ParseECPrivateKey(block.Bytes)
- if err != nil {
- return nil, err
- }
- rawkey = ec
+ return x509.ParseECPrivateKey(block.Bytes)
case "DSA PRIVATE KEY":
- ec, err := parseDSAPrivate(block.Bytes)
- if err != nil {
- return nil, err
- }
- rawkey = ec
+ return ParseDSAPrivateKey(block.Bytes)
default:
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
}
-
- return NewSignerFromKey(rawkey)
}
-// parseDSAPrivate parses a DSA key in ASN.1 DER encoding, as
-// documented in the OpenSSL DSA manpage.
-// TODO(hanwen): move this in to crypto/x509 after the Go 1.2 freeze.
-func parseDSAPrivate(p []byte) (*dsa.PrivateKey, error) {
- k := struct {
+// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
+// specified by the OpenSSL DSA man page.
+func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
+ var k struct {
Version int
P *big.Int
Q *big.Int
G *big.Int
Priv *big.Int
Pub *big.Int
- }{}
- rest, err := asn1.Unmarshal(p, &k)
+ }
+ rest, err := asn1.Unmarshal(der, &k)
if err != nil {
return nil, errors.New("ssh: failed to parse DSA key: " + err.Error())
}
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index 3c4b735..cd49565 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -1,66 +1,25 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
package ssh
import (
+ "bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
+ "encoding/base64"
+ "fmt"
"reflect"
"strings"
"testing"
+
+ "code.google.com/p/go.crypto/ssh/testdata"
)
-var (
- ecdsaKey Signer
- ecdsa384Key Signer
- ecdsa521Key Signer
- testCertKey Signer
-)
-
-type testSigner struct {
- Signer
- pub PublicKey
-}
-
-func (ts *testSigner) PublicKey() PublicKey {
- if ts.pub != nil {
- return ts.pub
- }
- return ts.Signer.PublicKey()
-}
-
-func init() {
- raw256, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
- ecdsaKey, _ = NewSignerFromKey(raw256)
-
- raw384, _ := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
- ecdsa384Key, _ = NewSignerFromKey(raw384)
-
- raw521, _ := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
- ecdsa521Key, _ = NewSignerFromKey(raw521)
-
- // Create a cert and sign it for use in tests.
- testCert := &OpenSSHCertV01{
- Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
- Key: ecdsaKey.PublicKey(),
- ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
- ValidAfter: 0, // unix epoch
- ValidBefore: maxUint64, // The end of currently representable time.
- Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
- SignatureKey: rsaKey.PublicKey(),
- }
- sigBytes, _ := rsaKey.Sign(rand.Reader, testCert.BytesForSigning())
- testCert.Signature = &signature{
- Format: testCert.SignatureKey.PublicKeyAlgo(),
- Blob: sigBytes,
- }
- testCertKey = &testSigner{
- Signer: ecdsaKey,
- pub: testCert,
- }
-}
-
func rawKey(pub PublicKey) interface{} {
switch k := pub.(type) {
case *rsaPublicKey:
@@ -69,23 +28,18 @@
return (*dsa.PublicKey)(k)
case *ecdsaPublicKey:
return (*ecdsa.PublicKey)(k)
- case *OpenSSHCertV01:
+ case *Certificate:
return k
}
panic("unknown key type")
}
func TestKeyMarshalParse(t *testing.T) {
- keys := []Signer{rsaKey, dsaKey, ecdsaKey, ecdsa384Key, ecdsa521Key, testCertKey}
- for _, priv := range keys {
+ for _, priv := range testSigners {
pub := priv.PublicKey()
- roundtrip, rest, ok := ParsePublicKey(MarshalPublicKey(pub))
- if !ok {
- t.Errorf("ParsePublicKey(%T) failed", pub)
- }
-
- if len(rest) > 0 {
- t.Errorf("ParsePublicKey(%T): trailing junk", pub)
+ roundtrip, err := ParsePublicKey(pub.Marshal())
+ if err != nil {
+ t.Errorf("ParsePublicKey(%T): %v", pub, err)
}
k1 := rawKey(pub)
@@ -113,9 +67,12 @@
}
func TestNewPublicKey(t *testing.T) {
- keys := []Signer{rsaKey, dsaKey, ecdsaKey}
- for _, k := range keys {
+ for _, k := range testSigners {
raw := rawKey(k.PublicKey())
+ // Skip certificates, as NewPublicKey does not support them.
+ if _, ok := raw.(*Certificate); ok {
+ continue
+ }
pub, err := NewPublicKey(raw)
if err != nil {
t.Errorf("NewPublicKey(%#v): %v", raw, err)
@@ -127,8 +84,7 @@
}
func TestKeySignVerify(t *testing.T) {
- keys := []Signer{rsaKey, dsaKey, ecdsaKey, testCertKey}
- for _, priv := range keys {
+ for _, priv := range testSigners {
pub := priv.PublicKey()
data := []byte("sign me")
@@ -137,19 +93,20 @@
t.Fatalf("Sign(%T): %v", priv, err)
}
- if !pub.Verify(data, sig) {
- t.Errorf("publicKey.Verify(%T) failed", priv)
+ if err := pub.Verify(data, sig); err != nil {
+ t.Errorf("publicKey.Verify(%T): %v", priv, err)
+ }
+ sig.Blob[5]++
+ if err := pub.Verify(data, sig); err == nil {
+ t.Errorf("publicKey.Verify on broken sig did not fail")
}
}
}
func TestParseRSAPrivateKey(t *testing.T) {
- key, err := ParsePrivateKey([]byte(testServerPrivateKey))
- if err != nil {
- t.Fatalf("ParsePrivateKey: %v", err)
- }
+ key := testPrivateKeys["rsa"]
- rsa, ok := key.(*rsaPrivateKey)
+ rsa, ok := key.(*rsa.PrivateKey)
if !ok {
- t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
+ t.Fatalf("got %T, want *rsa.PrivateKey", key)
}
@@ -160,21 +117,11 @@
}
func TestParseECPrivateKey(t *testing.T) {
- // Taken from the data in test/ .
- pem := []byte(`-----BEGIN EC PRIVATE KEY-----
-MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
-AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
-6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
------END EC PRIVATE KEY-----`)
+ key := testPrivateKeys["ecdsa"]
- key, err := ParsePrivateKey(pem)
- if err != nil {
- t.Fatalf("ParsePrivateKey: %v", err)
- }
-
- ecKey, ok := key.(*ecdsaPrivateKey)
+ ecKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
- t.Fatalf("got %T, want *ecdsaPrivateKey", ecKey)
+ t.Fatalf("got %T, want *ecdsa.PrivateKey", key)
}
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
@@ -182,22 +129,11 @@
}
}
-// ssh-keygen -t dsa -f /tmp/idsa.pem
-var dsaPEM = `-----BEGIN DSA PRIVATE KEY-----
-MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB
-lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
-EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD
-nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV
-2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r
-juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr
-FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz
-DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj
-nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY
-Fmsr0W6fHB9nhS4/UXM8
------END DSA PRIVATE KEY-----`
-
func TestParseDSA(t *testing.T) {
- s, err := ParsePrivateKey([]byte(dsaPEM))
+ // We actually exercise the ParsePrivateKey codepath here, as opposed to
+ // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
+ // uses.
+ s, err := ParsePrivateKey(testdata.PEMBytes["dsa"])
if err != nil {
t.Fatalf("ParsePrivateKey returned error: %s", err)
}
@@ -208,7 +144,163 @@
t.Fatalf("dsa.Sign: %v", err)
}
- if !s.PublicKey().Verify(data, sig) {
- t.Error("Verify failed.")
+ if err := s.PublicKey().Verify(data, sig); err != nil {
+ t.Errorf("Verify failed: %v", err)
+ }
+}
+
+// Tests for authorized_keys parsing.
+
+// getTestKey returns a public key, and its base64 encoding.
+func getTestKey() (PublicKey, string) {
+ k := testPublicKeys["rsa"]
+
+ b := &bytes.Buffer{}
+ e := base64.NewEncoder(base64.StdEncoding, b)
+ e.Write(k.Marshal())
+ e.Close()
+
+ return k, b.String()
+}
+
+func TestMarshalParsePublicKey(t *testing.T) {
+ pub, pubSerialized := getTestKey()
+ line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized)
+
+ authKeys := MarshalAuthorizedKey(pub)
+ actualFields := strings.Fields(string(authKeys))
+ if len(actualFields) == 0 {
+ t.Fatalf("failed authKeys: %v", authKeys)
+ }
+
+ // drop the comment
+ expectedFields := strings.Fields(line)[0:2]
+
+ if !reflect.DeepEqual(actualFields, expectedFields) {
+ t.Errorf("got %v, expected %v", actualFields, expectedFields)
+ }
+
+ actPub, _, _, _, err := ParseAuthorizedKey([]byte(line))
+ if err != nil {
+ t.Fatalf("cannot parse %v: %v", line, err)
+ }
+ if !reflect.DeepEqual(actPub, pub) {
+ t.Errorf("got %v, expected %v", actPub, pub)
+ }
+}
+
+type authResult struct {
+ pubKey PublicKey
+ options []string
+ comments string
+ rest string
+ ok bool
+}
+
+func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
+ rest := authKeys
+ var values []authResult
+ for len(rest) > 0 {
+ var r authResult
+ var err error
+ r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
+ r.ok = (err == nil)
+ t.Log(err)
+ r.rest = string(rest)
+ values = append(values, r)
+ }
+
+ if !reflect.DeepEqual(values, expected) {
+ t.Errorf("got %#v, expected %#v", values, expected)
+ }
+}
+
+func TestAuthorizedKeyBasic(t *testing.T) {
+ pub, pubSerialized := getTestKey()
+ line := "ssh-rsa " + pubSerialized + " user@host"
+ testAuthorizedKeys(t, []byte(line),
+ []authResult{
+ {pub, nil, "user@host", "", true},
+ })
+}
+
+func TestAuth(t *testing.T) {
+ pub, pubSerialized := getTestKey()
+ authWithOptions := []string{
+ `# comments to ignore before any keys...`,
+ ``,
+ `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`,
+ `# comments to ignore, along with a blank line`,
+ ``,
+ `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`,
+ ``,
+ `# more comments, plus a invalid entry`,
+ `ssh-rsa data-that-will-not-parse user@host3`,
+ }
+ for _, eol := range []string{"\n", "\r\n"} {
+ authOptions := strings.Join(authWithOptions, eol)
+ rest2 := strings.Join(authWithOptions[3:], eol)
+ rest3 := strings.Join(authWithOptions[6:], eol)
+ testAuthorizedKeys(t, []byte(authOptions), []authResult{
+ {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
+ {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
+ {nil, nil, "", "", false},
+ })
+ }
+}
+
+func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
+ pub, pubSerialized := getTestKey()
+ authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
+ testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
+ {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
+ })
+}
+
+func TestAuthWithQuotedCommaInEnv(t *testing.T) {
+ pub, pubSerialized := getTestKey()
+ authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
+ testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
+ {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
+ })
+}
+
+func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
+ pub, pubSerialized := getTestKey()
+ authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`)
+ authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
+ testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
+ {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
+ })
+
+ testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
+ {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
+ })
+}
+
+func TestAuthWithInvalidSpace(t *testing.T) {
+ _, pubSerialized := getTestKey()
+ authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
+#more to follow but still no valid keys`)
+ testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
+ {nil, nil, "", "", false},
+ })
+}
+
+func TestAuthWithMissingQuote(t *testing.T) {
+ pub, pubSerialized := getTestKey()
+ authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
+env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
+
+ testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
+ {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
+ })
+}
+
+func TestInvalidEntry(t *testing.T) {
+ authInvalid := []byte(`ssh-rsa`)
+ _, _, _, _, err := ParseAuthorizedKey(authInvalid)
+ if err == nil {
+ t.Errorf("got valid entry for %q", authInvalid)
}
}
diff --git a/ssh/mac.go b/ssh/mac.go
index 6862d3e..aff4042 100644
--- a/ssh/mac.go
+++ b/ssh/mac.go
@@ -43,11 +43,6 @@
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
-// Specifies a default set of MAC algorithms and a preference order.
-// This is based on RFC 4253, section 6.4, with the removal of the
-// hmac-md5 variants as they have reached the end of their useful life.
-var DefaultMACOrder = []string{"hmac-sha1", "hmac-sha1-96"}
-
var macModes = map[string]*macMode{
"hmac-sha1": {20, func(key []byte) hash.Hash {
return hmac.New(sha1.New, key)
diff --git a/ssh/mempipe_test.go b/ssh/mempipe_test.go
index ec1b854..69217a4 100644
--- a/ssh/mempipe_test.go
+++ b/ssh/mempipe_test.go
@@ -36,17 +36,23 @@
}
}
-func (t *memTransport) Close() error {
- t.write.Lock()
- defer t.write.Unlock()
- if t.write.eof {
+func (t *memTransport) closeSelf() error {
+ t.Lock()
+ defer t.Unlock()
+ if t.eof {
return io.EOF
}
- t.write.eof = true
- t.write.Cond.Broadcast()
+ t.eof = true
+ t.Cond.Broadcast()
return nil
}
+func (t *memTransport) Close() error {
+ err := t.write.closeSelf()
+ t.closeSelf()
+ return err
+}
+
func (t *memTransport) writePacket(p []byte) error {
t.write.Lock()
defer t.write.Unlock()
diff --git a/ssh/messages.go b/ssh/messages.go
index 94c3ea0..f9e44bb 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -7,58 +7,25 @@
import (
"bytes"
"encoding/binary"
+ "errors"
+ "fmt"
"io"
"math/big"
"reflect"
+ "strconv"
)
// These are SSH message type numbers. They are scattered around several
// documents but many were taken from [SSH-PARAMETERS].
const (
- msgDisconnect = 1
- msgIgnore = 2
- msgUnimplemented = 3
- msgDebug = 4
- msgServiceRequest = 5
- msgServiceAccept = 6
-
- msgKexInit = 20
- msgNewKeys = 21
-
- // Diffie-Helman
- msgKexDHInit = 30
- msgKexDHReply = 31
-
- msgKexECDHInit = 30
- msgKexECDHReply = 31
+ msgIgnore = 2
+ msgUnimplemented = 3
+ msgDebug = 4
+ msgNewKeys = 21
// Standard authentication messages
- msgUserAuthRequest = 50
- msgUserAuthFailure = 51
- msgUserAuthSuccess = 52
- msgUserAuthBanner = 53
- msgUserAuthPubKeyOk = 60
-
- // Method specific messages
- msgUserAuthInfoRequest = 60
- msgUserAuthInfoResponse = 61
-
- msgGlobalRequest = 80
- msgRequestSuccess = 81
- msgRequestFailure = 82
-
- // Channel manipulation
- msgChannelOpen = 90
- msgChannelOpenConfirm = 91
- msgChannelOpenFailure = 92
- msgChannelWindowAdjust = 93
- msgChannelData = 94
- msgChannelExtendedData = 95
- msgChannelEOF = 96
- msgChannelClose = 97
- msgChannelRequest = 98
- msgChannelSuccess = 99
- msgChannelFailure = 100
+ msgUserAuthSuccess = 52
+ msgUserAuthBanner = 53
)
// SSH messages:
@@ -69,15 +36,25 @@
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
// See RFC 4253, section 11.1.
+const msgDisconnect = 1
+
+// disconnectMsg is the message that signals a disconnect. It is also
+// the error type returned from mux.Wait()
type disconnectMsg struct {
- Reason uint32
+ Reason uint32 `sshtype:"1"`
Message string
Language string
}
+func (d *disconnectMsg) Error() string {
+ return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
+}
+
// See RFC 4253, section 7.1.
+const msgKexInit = 20
+
type kexInitMsg struct {
- Cookie [16]byte
+ Cookie [16]byte `sshtype:"20"`
KexAlgos []string
ServerHostKeyAlgos []string
CiphersClientServer []string
@@ -93,53 +70,74 @@
}
 // See RFC 4253, section 8.
+
+// Diffie-Hellman
+const msgKexDHInit = 30
+
 type kexDHInitMsg struct {
-	X *big.Int
+	X *big.Int `sshtype:"30"`
 }
+const msgKexECDHInit = 30
+
type kexECDHInitMsg struct {
- ClientPubKey []byte
+ ClientPubKey []byte `sshtype:"30"`
}
+const msgKexECDHReply = 31
+
type kexECDHReplyMsg struct {
- HostKey []byte
+ HostKey []byte `sshtype:"31"`
EphemeralPubKey []byte
Signature []byte
}
+const msgKexDHReply = 31
+
type kexDHReplyMsg struct {
- HostKey []byte
+ HostKey []byte `sshtype:"31"`
Y *big.Int
Signature []byte
}
// See RFC 4253, section 10.
+const msgServiceRequest = 5
+
type serviceRequestMsg struct {
- Service string
+ Service string `sshtype:"5"`
}
// See RFC 4253, section 10.
+const msgServiceAccept = 6
+
type serviceAcceptMsg struct {
- Service string
+ Service string `sshtype:"6"`
}
// See RFC 4252, section 5.
+const msgUserAuthRequest = 50
+
type userAuthRequestMsg struct {
- User string
+ User string `sshtype:"50"`
Service string
Method string
Payload []byte `ssh:"rest"`
}
// See RFC 4252, section 5.1
+const msgUserAuthFailure = 51
+
type userAuthFailureMsg struct {
- Methods []string
+ Methods []string `sshtype:"51"`
PartialSuccess bool
}
// See RFC 4256, section 3.2
+const msgUserAuthInfoRequest = 60
+const msgUserAuthInfoResponse = 61
+
type userAuthInfoRequestMsg struct {
- User string
+ User string `sshtype:"60"`
Instruction string
DeprecatedLanguage string
NumPrompts uint32
@@ -147,17 +145,24 @@
}
// See RFC 4254, section 5.1.
+const msgChannelOpen = 90
+
type channelOpenMsg struct {
- ChanType string
+ ChanType string `sshtype:"90"`
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
}
+const msgChannelExtendedData = 95
+const msgChannelData = 94
+
// See RFC 4254, section 5.1.
+const msgChannelOpenConfirm = 91
+
type channelOpenConfirmMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"91"`
MyId uint32
MyWindow uint32
MaxPacketSize uint32
@@ -165,172 +170,239 @@
}
// See RFC 4254, section 5.1.
+const msgChannelOpenFailure = 92
+
type channelOpenFailureMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"92"`
Reason RejectionReason
Message string
Language string
}
+const msgChannelRequest = 98
+
type channelRequestMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
}
// See RFC 4254, section 5.4.
+const msgChannelSuccess = 99
+
type channelRequestSuccessMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
+const msgChannelFailure = 100
+
type channelRequestFailureMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
+const msgChannelClose = 97
+
type channelCloseMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
+const msgChannelEOF = 96
+
type channelEOFMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"96"`
}
// See RFC 4254, section 4
+const msgGlobalRequest = 80
+
type globalRequestMsg struct {
- Type string
+ Type string `sshtype:"80"`
WantReply bool
+ Data []byte `ssh:"rest"`
}
// See RFC 4254, section 4
+const msgRequestSuccess = 81
+
type globalRequestSuccessMsg struct {
- Data []byte `ssh:"rest"`
+ Data []byte `ssh:"rest" sshtype:"81"`
}
// See RFC 4254, section 4
+const msgRequestFailure = 82
+
type globalRequestFailureMsg struct {
- Data []byte `ssh:"rest"`
+ Data []byte `ssh:"rest" sshtype:"82"`
}
// See RFC 4254, section 5.2
+const msgChannelWindowAdjust = 93
+
type windowAdjustMsg struct {
- PeersId uint32
+ PeersId uint32 `sshtype:"93"`
AdditionalBytes uint32
}
// See RFC 4252, section 7
+const msgUserAuthPubKeyOk = 60
+
type userAuthPubKeyOkMsg struct {
- Algo string
- PubKey string
+ Algo string `sshtype:"60"`
+ PubKey []byte
}
-// unmarshal parses the SSH wire data in packet into out using
-// reflection. expectedType, if non-zero, is the SSH message type that
-// the packet is expected to start with. unmarshal either returns nil
-// on success, or a ParseError or UnexpectedMessageError on error.
-func unmarshal(out interface{}, packet []byte, expectedType uint8) error {
- if len(packet) == 0 {
- return ParseError{expectedType}
+// typeTag returns the "sshtype" tag of the first field of structType (a
+// struct with at least one field) as a byte, or 0 if absent or non-numeric.
+func typeTag(structType reflect.Type) byte {
+	var tag byte
+	var tagStr string
+	tagStr = structType.Field(0).Tag.Get("sshtype")
+	i, err := strconv.Atoi(tagStr)
+	if err == nil {
+		tag = byte(i)
 	}
-	if expectedType > 0 {
-		if packet[0] != expectedType {
-			return UnexpectedMessageError{expectedType, packet[0]}
-		}
-		packet = packet[1:]
-	}
+	return tag
+}
+func fieldError(t reflect.Type, field int, problem string) error {
+ if problem != "" {
+ problem = ": " + problem
+ }
+ return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem)
+}
+
+var errShortRead = errors.New("ssh: short read")
+
+// Unmarshal parses data in SSH wire format into a structure. The out
+// argument should be a pointer to struct. If the first member of the
+// struct has the "sshtype" tag set to a decimal number, the packet must
+// start with that number. Errors report an empty or mistyped packet,
+// trailing data, a short read, or an unsupported field type.
+func Unmarshal(data []byte, out interface{}) error {
 	v := reflect.ValueOf(out).Elem()
 	structType := v.Type()
+	expectedType := typeTag(structType)
+	if len(data) == 0 {
+		return parseError(expectedType)
+	}
+	if expectedType > 0 {
+		if data[0] != expectedType {
+			return unexpectedMessageError(expectedType, data[0])
+		}
+		data = data[1:]
+	}
+
 	var ok bool
 	for i := 0; i < v.NumField(); i++ {
 		field := v.Field(i)
 		t := field.Type()
 		switch t.Kind() {
 		case reflect.Bool:
-			if len(packet) < 1 {
-				return ParseError{expectedType}
+			if len(data) < 1 {
+				return errShortRead
 			}
-			field.SetBool(packet[0] != 0)
-			packet = packet[1:]
+			field.SetBool(data[0] != 0)
+			data = data[1:]
 		case reflect.Array:
 			if t.Elem().Kind() != reflect.Uint8 {
-				panic("array of non-uint8")
+				return fieldError(structType, i, "array of unsupported type")
 			}
-			if len(packet) < t.Len() {
-				return ParseError{expectedType}
+			if len(data) < t.Len() {
+				return errShortRead
 			}
 			for j, n := 0, t.Len(); j < n; j++ {
-				field.Index(j).Set(reflect.ValueOf(packet[j]))
+				field.Index(j).Set(reflect.ValueOf(data[j]))
 			}
-			packet = packet[t.Len():]
+			data = data[t.Len():]
+		case reflect.Uint64:
+			var u64 uint64
+			if u64, data, ok = parseUint64(data); !ok {
+				return errShortRead
+			}
+			field.SetUint(u64)
 		case reflect.Uint32:
 			var u32 uint32
-			if u32, packet, ok = parseUint32(packet); !ok {
-				return ParseError{expectedType}
+			if u32, data, ok = parseUint32(data); !ok {
+				return errShortRead
 			}
 			field.SetUint(uint64(u32))
+		case reflect.Uint8:
+			if len(data) < 1 {
+				return errShortRead
+			}
+			field.SetUint(uint64(data[0]))
+			data = data[1:]
 		case reflect.String:
 			var s []byte
-			if s, packet, ok = parseString(packet); !ok {
-				return ParseError{expectedType}
+			if s, data, ok = parseString(data); !ok {
+				return fieldError(structType, i, "")
 			}
 			field.SetString(string(s))
 		case reflect.Slice:
 			switch t.Elem().Kind() {
 			case reflect.Uint8:
 				if structType.Field(i).Tag.Get("ssh") == "rest" {
-					field.Set(reflect.ValueOf(packet))
-					packet = nil
+					field.Set(reflect.ValueOf(data))
+					data = nil
 				} else {
 					var s []byte
-					if s, packet, ok = parseString(packet); !ok {
-						return ParseError{expectedType}
+					if s, data, ok = parseString(data); !ok {
+						return errShortRead
 					}
 					field.Set(reflect.ValueOf(s))
 				}
 			case reflect.String:
 				var nl []string
-				if nl, packet, ok = parseNameList(packet); !ok {
-					return ParseError{expectedType}
+				if nl, data, ok = parseNameList(data); !ok {
+					return errShortRead
 				}
 				field.Set(reflect.ValueOf(nl))
 			default:
-				panic("slice of unknown type")
+				return fieldError(structType, i, "slice of unsupported type")
 			}
 		case reflect.Ptr:
 			if t == bigIntType {
 				var n *big.Int
-				if n, packet, ok = parseInt(packet); !ok {
-					return ParseError{expectedType}
+				if n, data, ok = parseInt(data); !ok {
+					return errShortRead
 				}
 				field.Set(reflect.ValueOf(n))
 			} else {
-				panic("pointer to unknown type")
+				return fieldError(structType, i, "pointer to unsupported type")
 			}
 		default:
-			panic("unknown type")
+			return fieldError(structType, i, "unsupported type")
 		}
 	}
-	if len(packet) != 0 {
-		return ParseError{expectedType}
+	if len(data) != 0 {
+		return parseError(expectedType)
 	}
 	return nil
 }
-// marshal serializes the message in msg. The given message type is
-// prepended if it is non-zero.
-func marshal(msgType uint8, msg interface{}) []byte {
+// Marshal serializes the message in msg to SSH wire format. The msg
+// argument should be a struct or pointer to struct. If the first
+// member has the "sshtype" tag set to a number in decimal, that number
+// is prepended to the result. If the last member has the "ssh" tag set
+// to "rest", its contents are appended. Unsupported field types panic.
+func Marshal(msg interface{}) []byte {
 	out := make([]byte, 0, 64)
+	return marshalStruct(out, msg)
+}
+
+func marshalStruct(out []byte, msg interface{}) []byte {
+ v := reflect.Indirect(reflect.ValueOf(msg))
+ msgType := typeTag(v.Type())
if msgType > 0 {
out = append(out, msgType)
}
- v := reflect.ValueOf(msg)
for i, n := 0, v.NumField(); i < n; i++ {
field := v.Field(i)
switch t := field.Type(); t.Kind() {
@@ -342,13 +414,17 @@
out = append(out, v)
case reflect.Array:
if t.Elem().Kind() != reflect.Uint8 {
- panic("array of non-uint8")
+ panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
}
for j, l := 0, t.Len(); j < l; j++ {
out = append(out, uint8(field.Index(j).Uint()))
}
case reflect.Uint32:
out = appendU32(out, uint32(field.Uint()))
+ case reflect.Uint64:
+ out = appendU64(out, uint64(field.Uint()))
+ case reflect.Uint8:
+ out = append(out, uint8(field.Uint()))
case reflect.String:
s := field.String()
out = appendInt(out, len(s))
@@ -375,7 +451,7 @@
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
}
default:
- panic("slice of unknown type")
+ panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
}
case reflect.Ptr:
if t == bigIntType {
@@ -393,7 +469,7 @@
out = out[:oldLength+needed]
marshalInt(out[oldLength:], n)
} else {
- panic("pointer to unknown type")
+ panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
}
}
}
@@ -477,17 +553,6 @@
return binary.BigEndian.Uint64(in), in[8:], true
}
-func nameListLength(namelist []string) int {
- length := 4 /* uint32 length prefix */
- for i, name := range namelist {
- if i != 0 {
- length++ /* comma */
- }
- length += len(name)
- }
- return length
-}
-
func intLength(n *big.Int) int {
length := 4 /* length bytes */
if n.Sign() < 0 {
@@ -650,9 +715,9 @@
case msgChannelFailure:
msg = new(channelRequestFailureMsg)
default:
- return nil, UnexpectedMessageError{0, packet[0]}
+ return nil, unexpectedMessageError(0, packet[0])
}
- if err := unmarshal(msg, packet, packet[0]); err != nil {
+ if err := Unmarshal(packet, msg); err != nil {
return nil, err
}
return msg, nil
diff --git a/ssh/messages_test.go b/ssh/messages_test.go
index ec1d7be..f14c8a2 100644
--- a/ssh/messages_test.go
+++ b/ssh/messages_test.go
@@ -5,6 +5,7 @@
package ssh
import (
+ "bytes"
"math/big"
"math/rand"
"reflect"
@@ -32,48 +33,62 @@
}
}
-var messageTypes = []interface{}{
- &kexInitMsg{},
- &kexDHInitMsg{},
- &serviceRequestMsg{},
- &serviceAcceptMsg{},
- &userAuthRequestMsg{},
- &channelOpenMsg{},
- &channelOpenConfirmMsg{},
- &channelOpenFailureMsg{},
- &channelRequestMsg{},
- &channelRequestSuccessMsg{},
+type msgAllTypes struct {
+ Bool bool `sshtype:"21"`
+ Array [16]byte
+ Uint64 uint64
+ Uint32 uint32
+ Uint8 uint8
+ String string
+ Strings []string
+ Bytes []byte
+ Int *big.Int
+ Rest []byte `ssh:"rest"`
+}
+
+func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
+	m := &msgAllTypes{}
+	m.Bool = rand.Intn(2) == 1
+	randomBytes(m.Array[:], rand)
+	m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
+	m.Uint32 = rand.Uint32() // Intn(1 << 32) overflows int on 32-bit platforms.
+	m.Uint8 = uint8(rand.Intn(1 << 8))
+	m.String = string(m.Array[:])
+	m.Strings = randomNameList(rand)
+	m.Bytes = m.Array[:]
+	m.Int = randomInt(rand)
+	m.Rest = m.Array[:]
+	return reflect.ValueOf(m)
+}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
- for i, iface := range messageTypes {
- ty := reflect.ValueOf(iface).Type()
+ iface := &msgAllTypes{}
+ ty := reflect.ValueOf(iface).Type()
- n := 100
- if testing.Short() {
- n = 5
+ n := 100
+ if testing.Short() {
+ n = 5
+ }
+ for j := 0; j < n; j++ {
+ v, ok := quick.Value(ty, rand)
+ if !ok {
+ t.Errorf("failed to create value")
+ break
}
- for j := 0; j < n; j++ {
- v, ok := quick.Value(ty, rand)
- if !ok {
- t.Errorf("#%d: failed to create value", i)
- break
- }
- m1 := v.Elem().Interface()
- m2 := iface
+ m1 := v.Elem().Interface()
+ m2 := iface
- marshaled := marshal(msgIgnore, m1)
- if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
- t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
- break
- }
+ marshaled := Marshal(m1)
+ if err := Unmarshal(marshaled, m2); err != nil {
+ t.Errorf("Unmarshal %#v: %s", m1, err)
+ break
+ }
- if !reflect.DeepEqual(v.Interface(), m2) {
- t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
- break
- }
+ if !reflect.DeepEqual(v.Interface(), m2) {
+ t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
+ break
}
}
}
@@ -81,33 +96,37 @@
func TestUnmarshalEmptyPacket(t *testing.T) {
var b []byte
var m channelRequestSuccessMsg
- err := unmarshal(&m, b, msgChannelRequest)
- want := ParseError{msgChannelRequest}
- if _, ok := err.(ParseError); !ok {
- t.Fatalf("got %T, want %T", err, want)
- }
- if got := err.(ParseError); want != got {
- t.Fatal("got %#v, want %#v", got, want)
+ if err := Unmarshal(b, &m); err == nil {
+ t.Fatalf("unmarshal of empty slice succeeded")
}
}
func TestUnmarshalUnexpectedPacket(t *testing.T) {
type S struct {
- I uint32
+ I uint32 `sshtype:"43"`
S string
B bool
}
- s := S{42, "hello", true}
- packet := marshal(42, s)
+ s := S{11, "hello", true}
+ packet := Marshal(s)
+ packet[0] = 42
roundtrip := S{}
- err := unmarshal(&roundtrip, packet, 43)
+ err := Unmarshal(packet, &roundtrip)
if err == nil {
t.Fatal("expected error, not nil")
}
- want := UnexpectedMessageError{43, 42}
- if got, ok := err.(UnexpectedMessageError); !ok || want != got {
- t.Fatal("expected %q, got %q", want, got)
+}
+
+func TestMarshalPtr(t *testing.T) {
+ s := struct {
+ S string
+ }{"hello"}
+
+ m1 := Marshal(s)
+ m2 := Marshal(&s)
+ if !bytes.Equal(m1, m2) {
+ t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
}
}
@@ -119,9 +138,9 @@
}
s := S{42, "hello", true}
- packet := marshal(0, s)
+ packet := Marshal(s)
roundtrip := S{}
- unmarshal(&roundtrip, packet, 0)
+ Unmarshal(packet, &roundtrip)
if !reflect.DeepEqual(s, roundtrip) {
t.Errorf("got %#v, want %#v", roundtrip, s)
@@ -133,7 +152,7 @@
I uint32
}
s := S2{42}
- packet := marshal(0, s)
+ packet := Marshal(s)
i, rest, ok := parseUint32(packet)
if len(rest) > 0 || !ok {
t.Errorf("parseInt(%q): parse error", packet)
@@ -190,43 +209,36 @@
return reflect.ValueOf(dhi)
}
-// TODO(dfc) maybe this can be removed in the future if testing/quick can handle
-// derived basic types.
-func (RejectionReason) Generate(rand *rand.Rand, size int) reflect.Value {
- m := RejectionReason(Prohibited)
- return reflect.ValueOf(m)
-}
-
var (
_kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
- _kexInit = marshal(msgKexInit, _kexInitMsg)
- _kexDHInit = marshal(msgKexDHInit, _kexDHInitMsg)
+ _kexInit = Marshal(_kexInitMsg)
+ _kexDHInit = Marshal(_kexDHInitMsg)
)
func BenchmarkMarshalKexInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
- marshal(msgKexInit, _kexInitMsg)
+ Marshal(_kexInitMsg)
}
}
func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
m := new(kexInitMsg)
for i := 0; i < b.N; i++ {
- unmarshal(m, _kexInit, msgKexInit)
+ Unmarshal(_kexInit, m)
}
}
func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
for i := 0; i < b.N; i++ {
- marshal(msgKexDHInit, _kexDHInitMsg)
+ Marshal(_kexDHInitMsg)
}
}
func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
m := new(kexDHInitMsg)
for i := 0; i < b.N; i++ {
- unmarshal(m, _kexDHInit, msgKexDHInit)
+ Unmarshal(_kexDHInit, m)
}
}
diff --git a/ssh/mux.go b/ssh/mux.go
new file mode 100644
index 0000000..5af7c16
--- /dev/null
+++ b/ssh/mux.go
@@ -0,0 +1,352 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "log"
+ "sync"
+ "sync/atomic"
+)
+
+// debugMux, if set, causes messages in the connection protocol to be
+// logged.
+const debugMux = false
+
+// chanList is a thread safe channel list.
+type chanList struct {
+ // protects concurrent access to chans
+ sync.Mutex
+
+ // chans are indexed by the local id of the channel, which the
+ // other side should send in the PeersId field.
+ chans []*channel
+
+ // This is a debugging aid: it offsets all IDs by this
+ // amount. This helps distinguish otherwise identical
+ // server/client muxes
+ offset uint32
+}
+
+// add stores ch in the lowest free slot and returns its ID (slot index + offset).
+func (c *chanList) add(ch *channel) uint32 {
+	c.Lock()
+	defer c.Unlock()
+	for i := range c.chans {
+		if c.chans[i] == nil {
+			c.chans[i] = ch
+			return uint32(i) + c.offset
+		}
+	}
+	c.chans = append(c.chans, ch)
+	return uint32(len(c.chans)-1) + c.offset
+}
+
+// getChan returns the channel for the given ID, or nil if there is none.
+func (c *chanList) getChan(id uint32) *channel {
+	id -= c.offset
+
+	c.Lock()
+	defer c.Unlock()
+	if id < uint32(len(c.chans)) {
+		return c.chans[id]
+	}
+	return nil
+}
+
+func (c *chanList) remove(id uint32) {
+ id -= c.offset
+ c.Lock()
+ if id < uint32(len(c.chans)) {
+ c.chans[id] = nil
+ }
+ c.Unlock()
+}
+
+// dropAll forgets all channels it knows, returning them in a slice.
+func (c *chanList) dropAll() []*channel {
+ c.Lock()
+ defer c.Unlock()
+ var r []*channel
+
+ for _, ch := range c.chans {
+ if ch == nil {
+ continue
+ }
+ r = append(r, ch)
+ }
+ c.chans = nil
+ return r
+}
+
+// mux represents the state for the SSH connection protocol, which
+// multiplexes many channels onto a single packet transport.
+type mux struct {
+ conn packetConn
+ chanList chanList
+
+ incomingChannels chan NewChannel
+
+ globalSentMu sync.Mutex
+ globalResponses chan interface{}
+ incomingRequests chan *Request
+
+ errCond *sync.Cond
+ err error
+}
+
+// globalOff is bumped by each newMux call so every mux's chanList gets a distinct offset.
+var globalOff uint32
+
+func (m *mux) Wait() error {
+ m.errCond.L.Lock()
+ defer m.errCond.L.Unlock()
+ for m.err == nil {
+ m.errCond.Wait()
+ }
+ return m.err
+}
+
+// newMux returns a mux that runs over the given connection.
+func newMux(p packetConn) *mux {
+ m := &mux{
+ conn: p,
+ incomingChannels: make(chan NewChannel, 16),
+ globalResponses: make(chan interface{}, 1),
+ incomingRequests: make(chan *Request, 16),
+ errCond: newCond(),
+ }
+ m.chanList.offset = atomic.AddUint32(&globalOff, 1)
+ go m.loop()
+ return m
+}
+
+func (m *mux) sendMessage(msg interface{}) error {
+ p := Marshal(msg)
+ return m.conn.writePacket(p)
+}
+
+func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
+ if wantReply {
+ m.globalSentMu.Lock()
+ defer m.globalSentMu.Unlock()
+ }
+
+ if err := m.sendMessage(globalRequestMsg{
+ Type: name,
+ WantReply: wantReply,
+ Data: payload,
+ }); err != nil {
+ return false, nil, err
+ }
+
+ if !wantReply {
+ return false, nil, nil
+ }
+
+ msg, ok := <-m.globalResponses
+ if !ok {
+ return false, nil, io.EOF
+ }
+ switch msg := msg.(type) {
+ case *globalRequestFailureMsg:
+ return false, msg.Data, nil
+ case *globalRequestSuccessMsg:
+ return true, msg.Data, nil
+ default:
+ return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
+ }
+}
+
+// ackRequest must be called after processing a global request that
+// has WantReply set.
+func (m *mux) ackRequest(ok bool, data []byte) error {
+ if ok {
+ return m.sendMessage(globalRequestSuccessMsg{Data: data})
+ }
+ return m.sendMessage(globalRequestFailureMsg{Data: data})
+}
+
+// TODO(hanwen): Disconnect is a transport layer message. We should
+// probably send and receive Disconnect somewhere in the transport
+// code.
+
+// Disconnect sends a disconnect message.
+func (m *mux) Disconnect(reason uint32, message string) error {
+ return m.sendMessage(disconnectMsg{
+ Reason: reason,
+ Message: message,
+ })
+}
+
+func (m *mux) Close() error {
+ return m.conn.Close()
+}
+
+// loop runs the connection machine. It will process packets until an
+// error is encountered. To synchronize on loop exit, use mux.Wait.
+func (m *mux) loop() {
+ var err error
+ for err == nil {
+ err = m.onePacket()
+ }
+
+ for _, ch := range m.chanList.dropAll() {
+ ch.close()
+ }
+
+ close(m.incomingChannels)
+ close(m.incomingRequests)
+ close(m.globalResponses)
+
+ m.conn.Close()
+
+ m.errCond.L.Lock()
+ m.err = err
+ m.errCond.Broadcast()
+ m.errCond.L.Unlock()
+
+ if debugMux {
+ log.Println("loop exit", err)
+ }
+}
+
+// onePacket reads and processes one packet.
+func (m *mux) onePacket() error {
+ packet, err := m.conn.readPacket()
+ if err != nil {
+ return err
+ }
+
+ if debugMux {
+ if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
+ log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
+ } else {
+ p, _ := decode(packet)
+ log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
+ }
+ }
+
+ switch packet[0] {
+ case msgNewKeys:
+ // Ignore notification of key change.
+ return nil
+ case msgDisconnect:
+ return m.handleDisconnect(packet)
+ case msgChannelOpen:
+ return m.handleChannelOpen(packet)
+ case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
+ return m.handleGlobalPacket(packet)
+ }
+
+ // assume a channel packet.
+ if len(packet) < 5 {
+ return parseError(packet[0])
+ }
+ id := binary.BigEndian.Uint32(packet[1:])
+ ch := m.chanList.getChan(id)
+ if ch == nil {
+ return fmt.Errorf("ssh: invalid channel %d", id)
+ }
+
+ return ch.handlePacket(packet)
+}
+
+func (m *mux) handleDisconnect(packet []byte) error {
+ var d disconnectMsg
+ if err := Unmarshal(packet, &d); err != nil {
+ return err
+ }
+
+ if debugMux {
+ log.Printf("caught disconnect: %v", d)
+ }
+ return &d
+}
+
+func (m *mux) handleGlobalPacket(packet []byte) error {
+ msg, err := decode(packet)
+ if err != nil {
+ return err
+ }
+
+ switch msg := msg.(type) {
+ case *globalRequestMsg:
+ m.incomingRequests <- &Request{
+ Type: msg.Type,
+ WantReply: msg.WantReply,
+ Payload: msg.Data,
+ mux: m,
+ }
+ case *globalRequestSuccessMsg, *globalRequestFailureMsg:
+ m.globalResponses <- msg
+ default:
+ panic(fmt.Sprintf("not a global message %#v", msg))
+ }
+
+ return nil
+}
+
+// handleChannelOpen schedules a channel to be Accept()ed.
+func (m *mux) handleChannelOpen(packet []byte) error {
+ var msg channelOpenMsg
+ if err := Unmarshal(packet, &msg); err != nil {
+ return err
+ }
+
+ if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+ failMsg := channelOpenFailureMsg{
+ PeersId: msg.PeersId,
+ Reason: ConnectionFailed,
+ Message: "invalid request",
+ Language: "en_US.UTF-8",
+ }
+ return m.sendMessage(failMsg)
+ }
+
+ c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
+ c.remoteId = msg.PeersId
+ c.maxRemotePayload = msg.MaxPacketSize
+ c.remoteWin.add(msg.PeersWindow)
+ m.incomingChannels <- c
+ return nil
+}
+
+func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
+ ch, err := m.openChannel(chanType, extra)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return ch, ch.incomingRequests, nil
+}
+
+func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
+ ch := m.newChannel(chanType, channelOutbound, extra)
+
+ ch.maxIncomingPayload = channelMaxPacket
+
+ open := channelOpenMsg{
+ ChanType: chanType,
+ PeersWindow: ch.myWindow,
+ MaxPacketSize: ch.maxIncomingPayload,
+ TypeSpecificData: extra,
+ PeersId: ch.localId,
+ }
+ if err := m.sendMessage(open); err != nil {
+ return nil, err
+ }
+
+ switch msg := (<-ch.msg).(type) {
+ case *channelOpenConfirmMsg:
+ return ch, nil
+ case *channelOpenFailureMsg:
+ return nil, &OpenChannelError{msg.Reason, msg.Message}
+ default:
+ return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
+ }
+}
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
new file mode 100644
index 0000000..e18afe7
--- /dev/null
+++ b/ssh/mux_test.go
@@ -0,0 +1,483 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "io"
+ "io/ioutil"
+ "sync"
+ "testing"
+)
+
+func muxPair() (*mux, *mux) {
+ // Join two muxes over the two ends of an in-memory pipe.
+ left, right := memPipe()
+
+ first := newMux(left)
+ second := newMux(right)
+ return first, second
+}
+
+// channelPair returns both ends of a channel, and the mux for the 2nd
+// channel. Failures in the accepting goroutine use t.Errorf, not t.Fatalf.
+func channelPair(t *testing.T) (*channel, *channel, *mux) {
+ c, s := muxPair()
+
+ res := make(chan *channel, 1)
+ go func() {
+ newCh, ok := <-s.incomingChannels
+ if !ok {
+ t.Errorf("No incoming channel")
+ return
+ }
+ if newCh.ChannelType() != "chan" {
+ t.Errorf("got type %q want chan", newCh.ChannelType())
+ }
+ ch, _, err := newCh.Accept()
+ if err != nil {
+ t.Errorf("Accept %v", err)
+ return
+ }
+ res <- ch.(*channel)
+ }()
+ ch, err := c.openChannel("chan", nil)
+ if err != nil {
+ t.Fatalf("OpenChannel: %v", err)
+ }
+ return <-res, ch, c
+}
+
+func TestMuxReadWrite(t *testing.T) {
+ s, c, mux := channelPair(t)
+ defer s.Close()
+ defer c.Close()
+ defer mux.Close()
+ // s writes regular and extended data, then closes; t.Fatalf is not allowed in that goroutine.
+ magic := "hello world"
+ magicExt := "hello stderr"
+ go func() {
+ _, err := s.Write([]byte(magic))
+ if err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ _, err = s.Extended(1).Write([]byte(magicExt))
+ if err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ err = s.Close()
+ if err != nil {
+ t.Errorf("Close: %v", err)
+ }
+ }()
+ // c must read back the regular data, then the extended data.
+ var buf [1024]byte
+ n, err := c.Read(buf[:])
+ if err != nil {
+ t.Fatalf("server Read: %v", err)
+ }
+ got := string(buf[:n])
+ if got != magic {
+ t.Fatalf("server: got %q want %q", got, magic)
+ }
+
+ n, err = c.Extended(1).Read(buf[:])
+ if err != nil {
+ t.Fatalf("server Read: %v", err)
+ }
+
+ got = string(buf[:n])
+ if got != magicExt {
+ t.Fatalf("server: got %q want %q", got, magicExt)
+ }
+}
+
+func TestMuxChannelOverflow(t *testing.T) {
+ reader, writer, mux := channelPair(t)
+ defer reader.Close()
+ defer writer.Close()
+ defer mux.Close()
+
+ wDone := make(chan int, 1)
+ go func() {
+ if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+ t.Errorf("could not fill window: %v", err)
+ }
+ writer.Write(make([]byte, 1)) // expected to block on the full window; result deliberately ignored
+ wDone <- 1
+ }()
+ writer.remoteWin.waitWriterBlocked()
+
+ // Send 1 byte raw, bypassing flow control, to overflow the reader's window.
+ packet := make([]byte, 1+4+4+1)
+ packet[0] = msgChannelData
+ marshalUint32(packet[1:], writer.remoteId)
+ marshalUint32(packet[5:], uint32(1))
+ packet[9] = 42
+
+ if err := writer.mux.conn.writePacket(packet); err != nil {
+ t.Errorf("could not send packet: %v", err)
+ }
+ if _, err := reader.SendRequest("hello", true, nil); err == nil {
+ t.Errorf("SendRequest succeeded.")
+ }
+ <-wDone
+}
+
+func TestMuxChannelCloseWriteUnblock(t *testing.T) {
+ reader, writer, mux := channelPair(t)
+ defer reader.Close()
+ defer writer.Close()
+ defer mux.Close()
+
+ wDone := make(chan int, 1)
+ go func() {
+ if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+ t.Errorf("could not fill window: %v", err)
+ }
+ if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
+ t.Errorf("got %v, want EOF for unblock write", err)
+ }
+ wDone <- 1
+ }()
+ // Once the writer is blocked on a full window, closing the reader channel must unblock it with io.EOF.
+ writer.remoteWin.waitWriterBlocked()
+ reader.Close()
+ <-wDone
+}
+
+func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
+ reader, writer, mux := channelPair(t)
+ defer reader.Close()
+ defer writer.Close()
+ defer mux.Close()
+
+ wDone := make(chan int, 1)
+ go func() {
+ if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+ t.Errorf("could not fill window: %v", err)
+ }
+ if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
+ t.Errorf("got %v, want EOF for unblock write", err)
+ }
+ wDone <- 1
+ }()
+ // Once the writer is blocked on a full window, closing the whole mux must unblock it with io.EOF.
+ writer.remoteWin.waitWriterBlocked()
+ mux.Close()
+ <-wDone
+}
+
+func TestMuxReject(t *testing.T) {
+ client, server := muxPair()
+ defer server.Close()
+ defer client.Close()
+
+ go func() {
+ ch, ok := <-server.incomingChannels
+ if !ok {
+ t.Errorf("Accept")
+ return
+ }
+ if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
+ t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+ }
+ ch.Reject(RejectionReason(42), "message")
+ }()
+ ch, err := client.openChannel("ch", []byte("extra"))
+ if ch != nil {
+ t.Fatal("openChannel not rejected")
+ }
+ // The rejection reason and message must survive the round trip.
+ ocf, ok := err.(*OpenChannelError)
+ if !ok {
+ t.Errorf("got %#v want *OpenChannelError", err)
+ } else if ocf.Reason != 42 || ocf.Message != "message" {
+ t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
+ }
+ // Unknown reason codes are rendered numerically in the error text.
+ want := "ssh: rejected: unknown reason 42 (message)"
+ if err.Error() != want {
+ t.Errorf("got %q, want %q", err.Error(), want)
+ }
+}
+
+func TestMuxChannelRequest(t *testing.T) {
+ client, server, mux := channelPair(t)
+ defer server.Close()
+ defer client.Close()
+ defer mux.Close()
+ // server counts every request and answers true only for type "yes".
+ var received int
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ for r := range server.incomingRequests {
+ received++
+ r.Reply(r.Type == "yes", nil)
+ }
+ wg.Done()
+ }()
+ _, err := client.SendRequest("yes", false, nil)
+ if err != nil {
+ t.Fatalf("SendRequest: %v", err)
+ }
+ ok, err := client.SendRequest("yes", true, nil)
+ if err != nil {
+ t.Fatalf("SendRequest: %v", err)
+ }
+
+ if !ok {
+ t.Errorf("SendRequest(yes): %v", ok)
+
+ }
+
+ ok, err = client.SendRequest("no", true, nil)
+ if err != nil {
+ t.Fatalf("SendRequest: %v", err)
+ }
+ if ok {
+ t.Errorf("SendRequest(no): %v", ok)
+
+ }
+ // Closing client is expected to end server's request loop; wg.Wait relies on that.
+ client.Close()
+ wg.Wait()
+
+ if received != 3 {
+ t.Errorf("got %d requests, want %d", received, 3)
+ }
+}
+
+func TestMuxGlobalRequest(t *testing.T) {
+ clientMux, serverMux := muxPair()
+ defer serverMux.Close()
+ defer clientMux.Close()
+ // The server replies with ok = (type == "yes") and data = type followed by the payload.
+ var seen bool
+ go func() {
+ for r := range serverMux.incomingRequests {
+ seen = seen || r.Type == "peek"
+ if r.WantReply {
+ err := r.Reply(r.Type == "yes",
+ append([]byte(r.Type), r.Payload...))
+ if err != nil {
+ t.Errorf("Reply: %v", err)
+ }
+ }
+ }
+ }()
+
+ _, _, err := clientMux.SendRequest("peek", false, nil)
+ if err != nil {
+ t.Errorf("SendRequest: %v", err)
+ }
+
+ ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
+ if !ok || string(data) != "yesa" || err != nil {
+ t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
+ ok, data, err)
+ }
+ if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
+ t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
+ ok, data, err)
+ }
+
+ if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
+ t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
+ ok, data, err)
+ }
+ // NOTE(review): seen is written by the server goroutine without a lock; this read relies on the replies above for ordering - confirm under -race.
+ clientMux.Disconnect(0, "")
+ if !seen {
+ t.Errorf("never saw 'peek' request")
+ }
+}
+
+func TestMuxGlobalRequestUnblock(t *testing.T) {
+ clientMux, serverMux := muxPair()
+ defer serverMux.Close()
+ defer clientMux.Close()
+ // A pending global SendRequest must fail with io.EOF once the server's transport closes.
+ result := make(chan error, 1)
+ go func() {
+ _, _, err := clientMux.SendRequest("hello", true, nil)
+ result <- err
+ }()
+
+ <-serverMux.incomingRequests
+ serverMux.conn.Close()
+ err := <-result
+
+ if err != io.EOF {
+ t.Errorf("want EOF, got %v", err)
+ }
+}
+
+func TestMuxChannelRequestUnblock(t *testing.T) {
+ a, b, connB := channelPair(t)
+ defer a.Close()
+ defer b.Close()
+ defer connB.Close()
+ // Closing the transport under b must make a's pending channel SendRequest fail with io.EOF.
+ result := make(chan error, 1)
+ go func() {
+ _, err := a.SendRequest("hello", true, nil)
+ result <- err
+ }()
+
+ <-b.incomingRequests
+ connB.conn.Close()
+ err := <-result
+
+ if err != io.EOF {
+ t.Errorf("want EOF, got %v", err)
+ }
+}
+
+func TestMuxDisconnect(t *testing.T) {
+ a, b := muxPair()
+ defer a.Close()
+ defer b.Close()
+ // b acknowledges every global request it receives.
+ go func() {
+ for r := range b.incomingRequests {
+ r.Reply(true, nil)
+ }
+ }()
+ // After a disconnects, its requests must fail and b.Wait must report the disconnect reason.
+ a.Disconnect(42, "whatever")
+ ok, _, err := a.SendRequest("hello", true, nil)
+ if ok || err == nil {
+ t.Errorf("got reply after disconnecting")
+ }
+ err = b.Wait()
+ if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
+ t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
+ }
+}
+
+func TestMuxCloseChannel(t *testing.T) {
+ r, w, mux := channelPair(t)
+ defer mux.Close()
+ defer r.Close()
+ defer w.Close()
+ // After w.Close, writes on w and the pending read on r must both return io.EOF.
+ result := make(chan error, 1)
+ go func() {
+ var b [1024]byte
+ _, err := r.Read(b[:])
+ result <- err
+ }()
+ if err := w.Close(); err != nil {
+ t.Errorf("w.Close: %v", err)
+ }
+
+ if _, err := w.Write([]byte("hello")); err != io.EOF {
+ t.Errorf("got err %v, want io.EOF after Close", err)
+ }
+
+ if err := <-result; err != io.EOF {
+ t.Errorf("got %v (%T), want io.EOF", err, err)
+ }
+}
+
+func TestMuxCloseWriteChannel(t *testing.T) {
+ r, w, mux := channelPair(t)
+ defer mux.Close()
+ // After w.CloseWrite, writes on w and the pending read on r must both return io.EOF.
+ result := make(chan error, 1)
+ go func() {
+ var b [1024]byte
+ _, err := r.Read(b[:])
+ result <- err
+ }()
+ if err := w.CloseWrite(); err != nil {
+ t.Errorf("w.CloseWrite: %v", err)
+ }
+
+ if _, err := w.Write([]byte("hello")); err != io.EOF {
+ t.Errorf("got err %v, want io.EOF after CloseWrite", err)
+ }
+
+ if err := <-result; err != io.EOF {
+ t.Errorf("got %v (%T), want io.EOF", err, err)
+ }
+}
+
+func TestMuxInvalidRecord(t *testing.T) {
+ a, b := muxPair()
+ defer a.Close()
+ defer b.Close()
+ // Build a channel-data packet addressed to a channel id that was never opened.
+ packet := make([]byte, 1+4+4+1)
+ packet[0] = msgChannelData
+ marshalUint32(packet[1:], 29348723 /* invalid channel id */)
+ marshalUint32(packet[5:], 1)
+ packet[9] = 42
+ // Write it directly on a's connection; the write error is not checked.
+ a.conn.writePacket(packet)
+ go a.SendRequest("hello", false, nil)
+ // 'a' wrote an invalid packet, so 'b' has exited.
+ req, ok := <-b.incomingRequests
+ if ok {
+ t.Errorf("got request %#v after receiving invalid packet", req)
+ }
+}
+
+func TestZeroWindowAdjust(t *testing.T) {
+ a, b, mux := channelPair(t)
+ defer a.Close()
+ defer b.Close()
+ defer mux.Close()
+
+ go func() {
+ io.WriteString(a, "hello")
+ // Bogus zero-sized window adjust.
+ a.sendMessage(windowAdjustMsg{})
+ io.WriteString(a, "world")
+ a.Close()
+ }()
+ // The zero adjust must not disturb the stream, and ReadAll must end cleanly at EOF.
+ want := "helloworld"
+ c, err := ioutil.ReadAll(b)
+ if err != nil || string(c) != want {
+ t.Errorf("got %q (err %v), want %q", c, err, want)
+ }
+}
+
+func TestMuxMaxPacketSize(t *testing.T) {
+ a, b, mux := channelPair(t)
+ defer a.Close()
+ defer b.Close()
+ defer mux.Close()
+ // Hand-build a channel-data packet one byte over the peer's advertised maximum payload.
+ large := make([]byte, a.maxRemotePayload+1)
+ packet := make([]byte, 1+4+4+1+len(large))
+ packet[0] = msgChannelData
+ marshalUint32(packet[1:], a.remoteId)
+ marshalUint32(packet[5:], uint32(len(large)))
+ packet[9] = 42
+
+ if err := a.mux.conn.writePacket(packet); err != nil {
+ t.Errorf("could not send packet: %v", err)
+ }
+ // The receiving side is expected to drop the connection, closing b's request queue.
+ go a.SendRequest("hello", false, nil)
+
+ _, ok := <-b.incomingRequests
+ if ok {
+ t.Errorf("connection still alive after receiving large packet.")
+ }
+}
+
+// TestDebug fails if a debug switch was left on: don't ship code with debug=true.
+func TestDebug(t *testing.T) {
+ if debugMux {
+ t.Error("mux debug switched on")
+ }
+ if debugHandshake {
+ t.Error("handshake debug switched on")
+ }
+}
diff --git a/ssh/server.go b/ssh/server.go
index b4defbe..7a53d57 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -6,38 +6,52 @@
import (
"bytes"
- "crypto/rand"
- "encoding/binary"
"errors"
"fmt"
"io"
"net"
- "sync"
-
- _ "crypto/sha1"
)
-type ServerConfig struct {
- hostKeys []Signer
+// The Permissions type holds fine-grained permissions that are
+// specific to a user or a specific authentication method for a
+// user. Permissions, except for "source-address", must be enforced in
+// the server application layer, after successful authentication. The
+// Permissions are passed on in ServerConn so a server implementation
+// can honor them.
+type Permissions struct {
+ // Critical options restrict default permissions. Common
+ // restrictions are "source-address" and "force-command". If
+ // the server cannot enforce the restriction, or does not
+ // recognize it, the user should not authenticate.
+ CriticalOptions map[string]string
- // Rand provides the source of entropy for key exchange. If Rand is
- // nil, the cryptographic random reader in package crypto/rand will
- // be used.
- Rand io.Reader
+ // Extensions are extra functionality that the server may
+ // offer on authenticated connections. Common extensions are
+ // "permit-agent-forwarding", "permit-X11-forwarding". Lack of
+ // support for an extension does not preclude authenticating a
+ // user.
+ Extensions map[string]string
+}
+
+// ServerConfig holds server specific configuration data.
+type ServerConfig struct {
+ // Config contains configuration shared between client and server.
+ Config
+
+ hostKeys []Signer
// NoClientAuth is true if clients are allowed to connect without
// authenticating.
NoClientAuth bool
- // PasswordCallback, if non-nil, is called when a user attempts to
- // authenticate using a password. It may be called concurrently from
- // several goroutines.
- PasswordCallback func(conn *ServerConn, user, password string) bool
+ // PasswordCallback, if non-nil, is called when a user attempts to
+ // authenticate using a password; returning a nil error accepts it.
+ PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback, if non-nil, is called when a client attempts public
- // key authentication. It must return true if the given public key is
- // valid for the given user.
- PublicKeyCallback func(conn *ServerConn, user, algo string, pubkey []byte) bool
+ // key authentication. It must return a nil error if the given public key
+ // is valid for the given user. For example, see CertChecker.Authenticate.
+ PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback, if non-nil, is called when
// keyboard-interactive authentication is selected (RFC
@@ -46,24 +60,19 @@
// Challenge rounds. To avoid information leaks, the client
// should be presented a challenge even if the user is
// unknown.
- KeyboardInteractiveCallback func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool
+ KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
- // Cryptographic-related configuration.
- Crypto CryptoConfig
-}
-
-func (c *ServerConfig) rand() io.Reader {
- if c.Rand == nil {
- return rand.Reader
- }
- return c.Rand
+ // AuthLogCallback, if non-nil, is called to log all authentication
+ // attempts.
+ AuthLogCallback func(conn ConnMetadata, method string, err error)
}
// AddHostKey adds a private key as a host key. If an existing host
-// key exists with the same algorithm, it is overwritten.
+// key exists with the same algorithm, it is overwritten. Each server
+// config must have at least one host key, or NewServerConn fails.
func (s *ServerConfig) AddHostKey(key Signer) {
for i, k := range s.hostKeys {
- if k.PublicKey().PublicKeyAlgo() == key.PublicKey().PublicKeyAlgo() {
+ if k.PublicKey().Type() == key.PublicKey().Type() {
s.hostKeys[i] = key
return
}
@@ -72,68 +81,73 @@
s.hostKeys = append(s.hostKeys, key)
}
-// SetRSAPrivateKey sets the private key for a Server. A Server must have a
-// private key configured in order to accept connections. The private key must
-// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
-// typically contains such a key.
-func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error {
- priv, err := ParsePrivateKey(pemBytes)
- if err != nil {
- return err
- }
- s.AddHostKey(priv)
- return nil
+// cachedPubKey records the PublicKeyCallback outcome (error and Permissions)
+// for one user/public-key pair; Equal compares only the user and key bytes.
+type cachedPubKey struct {
+ user string
+ pubKeyData []byte
+ result error
+ perms *Permissions
}
-// cachedPubKey contains the results of querying whether a public key is
-// acceptable for a user. The cache only applies to a single ServerConn.
-type cachedPubKey struct {
- user, algo string
- pubKey []byte
- result bool
+func (k1 *cachedPubKey) Equal(k2 *cachedPubKey) bool {
+ return k1.user == k2.user && bytes.Equal(k1.pubKeyData, k2.pubKeyData)
}
const maxCachedPubKeys = 16
-// A ServerConn represents an incoming connection.
-type ServerConn struct {
- transport *transport
- config *ServerConfig
-
- channels map[uint32]*serverChan
- nextChanId uint32
-
- // lock protects err and channels.
- lock sync.Mutex
- err error
-
- // cachedPubKeys contains the cache results of tests for public keys.
- // Since SSH clients will query whether a public key is acceptable
- // before attempting to authenticate with it, we end up with duplicate
- // queries for public key validity.
- cachedPubKeys []cachedPubKey
-
- // User holds the successfully authenticated user name.
- // It is empty if no authentication is used. It is populated before
- // any authentication callback is called and not assigned to after that.
- User string
-
- // ClientVersion is the client's version, populated after
- // Handshake is called. It should not be modified.
- ClientVersion []byte
-
- // Our version.
- serverVersion []byte
+// pubKeyCache caches tests for public keys. Since SSH clients
+// will query whether a public key is acceptable before attempting to
+// authenticate with it, we end up with duplicate queries for public
+// key validity. The cache only applies to a single ServerConn.
+type pubKeyCache struct {
+ keys []cachedPubKey
}
-// Server returns a new SSH server connection
-// using c as the underlying transport.
-func Server(c net.Conn, config *ServerConfig) *ServerConn {
- return &ServerConn{
- transport: newTransport(c, config.rand(), false /* not client */),
- channels: make(map[uint32]*serverChan),
- config: config,
+// get returns the cached error (not the Permissions) for candidate's user and key.
+func (c *pubKeyCache) get(candidate cachedPubKey) (result error, ok bool) {
+ for _, k := range c.keys {
+ if k.Equal(&candidate) {
+ return k.result, true
+ }
}
+ return errors.New("ssh: not in cache"), false
+}
+
+// add caches candidate, silently dropping it once maxCachedPubKeys entries exist.
+func (c *pubKeyCache) add(candidate cachedPubKey) {
+ if len(c.keys) < maxCachedPubKeys {
+ c.keys = append(c.keys, candidate)
+ }
+}
+
+// ServerConn is an authenticated SSH connection, as seen from the
+// server.
+type ServerConn struct {
+ Conn
+
+ // If the succeeding authentication callback returned a
+ // non-nil Permissions pointer, it is stored here.
+ Permissions *Permissions
+}
+
+// NewServerConn starts a new SSH server with c as the underlying
+// transport. It starts with a handshake and, if the handshake is
+// unsuccessful, it closes the connection and returns an error. The
+// Request and NewChannel channels must be serviced, or the connection
+// will hang.
+func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
+ fullConf := *config
+ fullConf.SetDefaults()
+ s := &connection{
+ sshConn: sshConn{conn: c},
+ }
+ perms, err := s.serverHandshake(&fullConf)
+ if err != nil {
+ c.Close()
+ return nil, nil, nil, err
+ }
+ return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
}
// signAndMarshal signs the data with the appropriate algorithm,
@@ -144,134 +158,60 @@
return nil, err
}
- return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
+ return Marshal(sig), nil
}
-// Close closes the connection.
-func (s *ServerConn) Close() error { return s.transport.Close() }
+// serverHandshake performs version exchange, key exchange and user authentication.
+func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
+ if len(config.hostKeys) == 0 {
+ return nil, errors.New("ssh: server has no host keys")
+ }
-// LocalAddr returns the local network address.
-func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
-
-// RemoteAddr returns the remote network address.
-func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
-
-// Handshake performs an SSH transport and client authentication on the given ServerConn.
-func (s *ServerConn) Handshake() error {
var err error
s.serverVersion = []byte(packageVersion)
- s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersion)
+ s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
if err != nil {
- return err
+ return nil, err
}
- if err := s.clientInitHandshake(nil, nil); err != nil {
- return err
+ // Run the initial key exchange, then require the peer's msgNewKeys.
+ tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
+ s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
+
+ if err := s.transport.requestKeyChange(); err != nil {
+ return nil, err
+ }
+
+ if packet, err := s.transport.readPacket(); err != nil {
+ return nil, err
+ } else if packet[0] != msgNewKeys {
+ return nil, unexpectedMessageError(msgNewKeys, packet[0])
}
var packet []byte
if packet, err = s.transport.readPacket(); err != nil {
- return err
+ return nil, err
}
+ // The first service request must be for user authentication.
var serviceRequest serviceRequestMsg
- if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
- return err
+ if err = Unmarshal(packet, &serviceRequest); err != nil {
+ return nil, err
}
if serviceRequest.Service != serviceUserAuth {
- return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
+ return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
}
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
- if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
- return err
+ if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
+ return nil, err
}
- if err := s.authenticate(); err != nil {
- return err
- }
- return err
-}
-
-func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
- serverKexInit := kexInitMsg{
- KexAlgos: s.config.Crypto.kexes(),
- CiphersClientServer: s.config.Crypto.ciphers(),
- CiphersServerClient: s.config.Crypto.ciphers(),
- MACsClientServer: s.config.Crypto.macs(),
- MACsServerClient: s.config.Crypto.macs(),
- CompressionClientServer: supportedCompressions,
- CompressionServerClient: supportedCompressions,
- }
- for _, k := range s.config.hostKeys {
- serverKexInit.ServerHostKeyAlgos = append(
- serverKexInit.ServerHostKeyAlgos, k.PublicKey().PublicKeyAlgo())
- }
-
- serverKexInitPacket := marshal(msgKexInit, serverKexInit)
- if err = s.transport.writePacket(serverKexInitPacket); err != nil {
- return
- }
-
- if clientKexInitPacket == nil {
- clientKexInit = new(kexInitMsg)
- if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
- return
- }
- if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
- return
- }
- }
-
- algs := findAgreedAlgorithms(clientKexInit, &serverKexInit)
- if algs == nil {
- return errors.New("ssh: no common algorithms")
- }
-
- if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
- // The client sent a Kex message for the wrong algorithm,
- // which we have to ignore.
- if _, err = s.transport.readPacket(); err != nil {
- return
- }
- }
-
- var hostKey Signer
- for _, k := range s.config.hostKeys {
- if algs.hostKey == k.PublicKey().PublicKeyAlgo() {
- hostKey = k
- }
- }
-
- kex, ok := kexAlgoMap[algs.kex]
- if !ok {
- return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
- }
-
- magics := handshakeMagics{
- serverVersion: s.serverVersion,
- clientVersion: s.ClientVersion,
- serverKexInit: marshal(msgKexInit, serverKexInit),
- clientKexInit: clientKexInitPacket,
- }
- result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
+ perms, err := s.serverAuthenticate(config)
if err != nil {
- return err
+ return nil, err
}
-
- if err = s.transport.prepareKeyChange(algs, result); err != nil {
- return err
- }
-
- if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
- return
- }
- if packet, err := s.transport.readPacket(); err != nil {
- return err
- } else if packet[0] != msgNewKeys {
- return UnexpectedMessageError{msgNewKeys, packet[0]}
- }
-
- return
+ s.mux = newMux(s.transport)
+ return perms, err
}
func isAcceptableAlgo(algo string) bool {
@@ -283,181 +223,213 @@
return false
}
-// testPubKey returns true if the given public key is acceptable for the user.
-func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool {
- if s.config.PublicKeyCallback == nil || !isAcceptableAlgo(algo) {
- return false
+func checkSourceAddress(addr net.Addr, sourceAddr string) error {
+ if addr == nil {
+ return errors.New("ssh: no address known for client, but source-address match required")
}
- for _, c := range s.cachedPubKeys {
- if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
- return c.result
+ tcpAddr, ok := addr.(*net.TCPAddr)
+ if !ok {
+ return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
+ }
+ // sourceAddr is either a single IP address or a CIDR range.
+ if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
+ if allowedIP.Equal(tcpAddr.IP) { // unlike bytes.Equal, matches 4- and 16-byte IPv4 forms
+ return nil
+ }
+ } else {
+ _, ipNet, err := net.ParseCIDR(sourceAddr)
+ if err != nil {
+ return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
+ }
+
+ if ipNet.Contains(tcpAddr.IP) {
+ return nil
}
}
- result := s.config.PublicKeyCallback(s, user, algo, pubKey)
- if len(s.cachedPubKeys) < maxCachedPubKeys {
- c := cachedPubKey{
- user: user,
- algo: algo,
- pubKey: make([]byte, len(pubKey)),
- result: result,
- }
- copy(c.pubKey, pubKey)
- s.cachedPubKeys = append(s.cachedPubKeys, c)
- }
-
- return result
+ return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
-func (s *ServerConn) authenticate() error {
- var userAuthReq userAuthRequestMsg
+func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
var err error
- var packet []byte
+ var cache pubKeyCache
+ var perms *Permissions
userAuthLoop:
for {
- if packet, err = s.transport.readPacket(); err != nil {
- return err
- }
- if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
- return err
+ var userAuthReq userAuthRequestMsg
+ if packet, err := s.transport.readPacket(); err != nil {
+ return nil, err
+ } else if err = Unmarshal(packet, &userAuthReq); err != nil {
+ return nil, err
}
if userAuthReq.Service != serviceSSH {
- return errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
+ return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
}
+ s.user = userAuthReq.User
+ perms = nil
+ authErr := errors.New("no auth passed yet")
+
switch userAuthReq.Method {
case "none":
- if s.config.NoClientAuth {
- break userAuthLoop
+ if config.NoClientAuth {
+ s.user = ""
+ authErr = nil
}
case "password":
- if s.config.PasswordCallback == nil {
+ if config.PasswordCallback == nil {
+ authErr = errors.New("ssh: password auth not configured")
break
}
payload := userAuthReq.Payload
if len(payload) < 1 || payload[0] != 0 {
- return ParseError{msgUserAuthRequest}
+ return nil, parseError(msgUserAuthRequest)
}
payload = payload[1:]
password, payload, ok := parseString(payload)
if !ok || len(payload) > 0 {
- return ParseError{msgUserAuthRequest}
+ return nil, parseError(msgUserAuthRequest)
}
- s.User = userAuthReq.User
- if s.config.PasswordCallback(s, userAuthReq.User, string(password)) {
- break userAuthLoop
- }
+ perms, authErr = config.PasswordCallback(s, password)
case "keyboard-interactive":
- if s.config.KeyboardInteractiveCallback == nil {
+ if config.KeyboardInteractiveCallback == nil {
+ authErr = errors.New("ssh: keyboard-interactive auth not configured")
break
}
- s.User = userAuthReq.User
- if s.config.KeyboardInteractiveCallback(s, s.User, &sshClientKeyboardInteractive{s}) {
- break userAuthLoop
- }
+ prompter := &sshClientKeyboardInteractive{s}
+ perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
case "publickey":
- if s.config.PublicKeyCallback == nil {
+ if config.PublicKeyCallback == nil {
+ authErr = errors.New("ssh: publickey auth not configured")
break
}
payload := userAuthReq.Payload
if len(payload) < 1 {
- return ParseError{msgUserAuthRequest}
+ return nil, parseError(msgUserAuthRequest)
}
isQuery := payload[0] == 0
payload = payload[1:]
algoBytes, payload, ok := parseString(payload)
if !ok {
- return ParseError{msgUserAuthRequest}
+ return nil, parseError(msgUserAuthRequest)
}
algo := string(algoBytes)
-
- pubKey, payload, ok := parseString(payload)
- if !ok {
- return ParseError{msgUserAuthRequest}
+ if !isAcceptableAlgo(algo) {
+ authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
+ break
}
+
+ pubKeyData, payload, ok := parseString(payload)
+ if !ok {
+ return nil, parseError(msgUserAuthRequest)
+ }
+
+ pubKey, err := ParsePublicKey(pubKeyData)
+ if err != nil {
+ return nil, err
+ }
+ candidate := cachedPubKey{user: s.user, pubKeyData: pubKeyData}
+ found := false // reuse the whole cached entry so its Permissions survive the query round
+ for _, k := range cache.keys {
+ if k.Equal(&candidate) {
+ candidate, found = k, true
+ }
+ }
+ if !found {
+ candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
+ if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
+ candidate.result = checkSourceAddress(s.RemoteAddr(), candidate.perms.CriticalOptions[sourceAddressCriticalOption])
+ }
+ cache.add(candidate)
+ }
+
if isQuery {
// The client can query if the given public key
// would be okay.
if len(payload) > 0 {
- return ParseError{msgUserAuthRequest}
+ return nil, parseError(msgUserAuthRequest)
}
- if s.testPubKey(userAuthReq.User, algo, pubKey) {
+
+ if candidate.result == nil {
okMsg := userAuthPubKeyOkMsg{
Algo: algo,
- PubKey: string(pubKey),
+ PubKey: pubKeyData,
}
- if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
- return err
+ if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
+ return nil, err
}
continue userAuthLoop
}
+ authErr = candidate.result
} else {
sig, payload, ok := parseSignature(payload)
if !ok || len(payload) > 0 {
- return ParseError{msgUserAuthRequest}
+ return nil, parseError(msgUserAuthRequest)
}
// Ensure the public key algo and signature algo
// are supported. Compare the private key
// algorithm name that corresponds to algo with
// sig.Format. This is usually the same, but
// for certs, the names differ.
- if !isAcceptableAlgo(algo) || !isAcceptableAlgo(sig.Format) || pubAlgoToPrivAlgo(algo) != sig.Format {
+ if !isAcceptableAlgo(sig.Format) {
break
}
- signedData := buildDataSignedForAuth(s.transport.sessionID, userAuthReq, algoBytes, pubKey)
- key, _, ok := ParsePublicKey(pubKey)
- if !ok {
- return ParseError{msgUserAuthRequest}
+ signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
+
+ if err := pubKey.Verify(signedData, sig); err != nil {
+ return nil, err
}
- if !key.Verify(signedData, sig.Blob) {
- return ParseError{msgUserAuthRequest}
- }
- // TODO(jmpittman): Implement full validation for certificates.
- s.User = userAuthReq.User
- if s.testPubKey(userAuthReq.User, algo, pubKey) {
- break userAuthLoop
- }
+ authErr = candidate.result
+ perms = candidate.perms
}
+ default:
+ authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
+ }
+
+ if config.AuthLogCallback != nil {
+ config.AuthLogCallback(s, userAuthReq.Method, authErr)
+ }
+
+ if authErr == nil {
+ break userAuthLoop
}
var failureMsg userAuthFailureMsg
- if s.config.PasswordCallback != nil {
+ if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
}
- if s.config.PublicKeyCallback != nil {
+ if config.PublicKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey")
}
- if s.config.KeyboardInteractiveCallback != nil {
+ if config.KeyboardInteractiveCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
}
if len(failureMsg.Methods) == 0 {
- return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
+ return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
- if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
- return err
+ if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
+ return nil, err
}
}
- packet = []byte{msgUserAuthSuccess}
- if err = s.transport.writePacket(packet); err != nil {
- return err
+ if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
+ return nil, err
}
-
- return nil
+ return perms, nil
}
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
// asking the client on the other side of a ServerConn.
type sshClientKeyboardInteractive struct {
- *ServerConn
+ *connection
}
func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
@@ -471,7 +443,7 @@
prompts = appendBool(prompts, echos[i])
}
- if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
+ if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
Instruction: instruction,
NumPrompts: uint32(len(questions)),
Prompts: prompts,
@@ -484,19 +456,19 @@
return nil, err
}
if packet[0] != msgUserAuthInfoResponse {
- return nil, UnexpectedMessageError{msgUserAuthInfoResponse, packet[0]}
+ return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
}
packet = packet[1:]
n, packet, ok := parseUint32(packet)
if !ok || int(n) != len(questions) {
- return nil, &ParseError{msgUserAuthInfoResponse}
+ return nil, parseError(msgUserAuthInfoResponse)
}
for i := uint32(0); i < n; i++ {
ans, rest, ok := parseString(packet)
if !ok {
- return nil, &ParseError{msgUserAuthInfoResponse}
+ return nil, parseError(msgUserAuthInfoResponse)
}
answers = append(answers, string(ans))
@@ -508,185 +480,3 @@
return answers, nil
}
-
-const defaultWindowSize = 32768
-
-// Accept reads and processes messages on a ServerConn. It must be called
-// in order to demultiplex messages to any resulting Channels.
-func (s *ServerConn) Accept() (Channel, error) {
- // TODO(dfc) s.lock is not held here so visibility of s.err is not guaranteed.
- if s.err != nil {
- return nil, s.err
- }
-
- for {
- packet, err := s.transport.readPacket()
- if err != nil {
-
- s.lock.Lock()
- s.err = err
- s.lock.Unlock()
-
- // TODO(dfc) s.lock protects s.channels but isn't being held here.
- for _, c := range s.channels {
- c.setDead()
- c.handleData(nil)
- }
-
- return nil, err
- }
-
- switch packet[0] {
- case msgChannelData:
- if len(packet) < 9 {
- // malformed data packet
- return nil, ParseError{msgChannelData}
- }
- remoteId := binary.BigEndian.Uint32(packet[1:5])
- s.lock.Lock()
- c, ok := s.channels[remoteId]
- if !ok {
- s.lock.Unlock()
- continue
- }
- if length := binary.BigEndian.Uint32(packet[5:9]); length > 0 {
- packet = packet[9:]
- c.handleData(packet[:length])
- }
- s.lock.Unlock()
- default:
- decoded, err := decode(packet)
- if err != nil {
- return nil, err
- }
- switch msg := decoded.(type) {
- case *channelOpenMsg:
- if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
- return nil, errors.New("ssh: invalid MaxPacketSize from peer")
- }
- c := &serverChan{
- channel: channel{
- packetConn: s.transport,
- remoteId: msg.PeersId,
- remoteWin: window{Cond: newCond()},
- maxPacket: msg.MaxPacketSize,
- },
- chanType: msg.ChanType,
- extraData: msg.TypeSpecificData,
- myWindow: defaultWindowSize,
- serverConn: s,
- cond: newCond(),
- pendingData: make([]byte, defaultWindowSize),
- }
- c.remoteWin.add(msg.PeersWindow)
- s.lock.Lock()
- c.localId = s.nextChanId
- s.nextChanId++
- s.channels[c.localId] = c
- s.lock.Unlock()
- return c, nil
-
- case *channelRequestMsg:
- s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
- if !ok {
- s.lock.Unlock()
- continue
- }
- c.handlePacket(msg)
- s.lock.Unlock()
-
- case *windowAdjustMsg:
- s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
- if !ok {
- s.lock.Unlock()
- continue
- }
- c.handlePacket(msg)
- s.lock.Unlock()
-
- case *channelEOFMsg:
- s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
- if !ok {
- s.lock.Unlock()
- continue
- }
- c.handlePacket(msg)
- s.lock.Unlock()
-
- case *channelCloseMsg:
- s.lock.Lock()
- c, ok := s.channels[msg.PeersId]
- if !ok {
- s.lock.Unlock()
- continue
- }
- c.handlePacket(msg)
- s.lock.Unlock()
-
- case *globalRequestMsg:
- if msg.WantReply {
- if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
- return nil, err
- }
- }
-
- case *kexInitMsg:
- s.lock.Lock()
- if err := s.clientInitHandshake(msg, packet); err != nil {
- s.lock.Unlock()
- return nil, err
- }
- s.lock.Unlock()
- case *disconnectMsg:
- return nil, io.EOF
- default:
- // Unknown message. Ignore.
- }
- }
- }
-
- panic("unreachable")
-}
-
-// A Listener implements a network listener (net.Listener) for SSH connections.
-type Listener struct {
- listener net.Listener
- config *ServerConfig
-}
-
-// Addr returns the listener's network address.
-func (l *Listener) Addr() net.Addr {
- return l.listener.Addr()
-}
-
-// Close closes the listener.
-func (l *Listener) Close() error {
- return l.listener.Close()
-}
-
-// Accept waits for and returns the next incoming SSH connection.
-// The receiver should call Handshake() in another goroutine
-// to avoid blocking the accepter.
-func (l *Listener) Accept() (*ServerConn, error) {
- c, err := l.listener.Accept()
- if err != nil {
- return nil, err
- }
- return Server(c, l.config), nil
-}
-
-// Listen creates an SSH listener accepting connections on
-// the given network address using net.Listen.
-func Listen(network, addr string, config *ServerConfig) (*Listener, error) {
- l, err := net.Listen(network, addr)
- if err != nil {
- return nil, err
- }
- return &Listener{
- l,
- config,
- }, nil
-}
diff --git a/ssh/session.go b/ssh/session.go
index 39f2d22..3b42b50 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -129,128 +129,126 @@
Stdout io.Writer
Stderr io.Writer
- *clientChan // the channel backing this session
-
- started bool // true once Start, Run or Shell is invoked.
+ ch Channel // the channel backing this session
+ started bool // true once Start, Run or Shell is invoked.
copyFuncs []func() error
errors chan error // one send per copyFunc
// true if pipe method is active
stdinpipe, stdoutpipe, stderrpipe bool
+
+ // stdinPipeWriter is non-nil if StdinPipe has not been called
+ // and Stdin was specified by the user; it is the write end of
+ // a pipe connecting Session.Stdin to the stdin channel.
+ stdinPipeWriter io.WriteCloser
+
+ exitStatus chan error
+}
+
+// SendRequest sends an out-of-band channel request on the SSH channel
+// underlying the session.
+func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+ return s.ch.SendRequest(name, wantReply, payload)
+}
+
+// Close closes the SSH channel underlying the session.
+func (s *Session) Close() error {
+ return s.ch.Close()
}
// RFC 4254 Section 6.4.
type setenvRequest struct {
- PeersId uint32
- Request string
- WantReply bool
- Name string
- Value string
-}
-
-// RFC 4254 Section 6.5.
-type subsystemRequestMsg struct {
- PeersId uint32
- Request string
- WantReply bool
- Subsystem string
+ Name string
+ Value string
}
// Setenv sets an environment variable that will be applied to any
// command executed by Shell or Run.
func (s *Session) Setenv(name, value string) error {
- req := setenvRequest{
- PeersId: s.remoteId,
- Request: "env",
- WantReply: true,
- Name: name,
- Value: value,
+ msg := setenvRequest{
+ Name: name,
+ Value: value,
}
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
- return err
+ ok, err := s.ch.SendRequest("env", true, Marshal(&msg))
+ if err == nil && !ok {
+ err = errors.New("ssh: setenv failed")
}
- return s.waitForResponse()
+ return err
}
// RFC 4254 Section 6.2.
type ptyRequestMsg struct {
- PeersId uint32
- Request string
- WantReply bool
- Term string
- Columns uint32
- Rows uint32
- Width uint32
- Height uint32
- Modelist string
+ Term string
+ Columns uint32
+ Rows uint32
+ Width uint32
+ Height uint32
+ Modelist string
}
// RequestPty requests the association of a pty with the session on the remote host.
+// h and w are sent as the pty's row and column counts respectively.
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error {
var tm []byte
for k, v := range termmodes {
- tm = append(tm, k)
- tm = appendU32(tm, v)
+ // Each mode is encoded as an opcode byte followed by a uint32 value.
+ kv := struct {
+ Key byte
+ Val uint32
+ }{k, v}
+
+ tm = append(tm, Marshal(&kv)...)
}
tm = append(tm, tty_OP_END)
req := ptyRequestMsg{
- PeersId: s.remoteId,
- Request: "pty-req",
- WantReply: true,
- Term: term,
- Columns: uint32(w),
- Rows: uint32(h),
- Width: uint32(w * 8),
- Height: uint32(h * 8),
- Modelist: string(tm),
+ Term: term,
+ Columns: uint32(w),
+ Rows: uint32(h),
+ // Pixel dimensions are estimated as 8 pixels per column/row.
+ Width: uint32(w * 8),
+ Height: uint32(h * 8),
+ Modelist: string(tm),
}
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
- return err
+ ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req))
+ if err == nil && !ok {
+ err = errors.New("ssh: pty-req failed")
}
- return s.waitForResponse()
+ return err
+}
+
+// RFC 4254 Section 6.5.
+type subsystemRequestMsg struct {
+ Subsystem string
}
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
func (s *Session) RequestSubsystem(subsystem string) error {
- req := subsystemRequestMsg{
- PeersId: s.remoteId,
- Request: "subsystem",
- WantReply: true,
+ msg := subsystemRequestMsg{
Subsystem: subsystem,
}
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
- return err
+ ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg))
+ if err == nil && !ok {
+ err = errors.New("ssh: subsystem request failed")
}
- return s.waitForResponse()
+ return err
}
// RFC 4254 Section 6.9.
type signalMsg struct {
- PeersId uint32
- Request string
- WantReply bool
- Signal string
+ Signal string
}
// Signal sends the given signal to the remote process.
// sig is one of the SIG* constants.
func (s *Session) Signal(sig Signal) error {
- req := signalMsg{
- PeersId: s.remoteId,
- Request: "signal",
- WantReply: false,
- Signal: string(sig),
+ msg := signalMsg{
+ Signal: string(sig),
}
- return s.writePacket(marshal(msgChannelRequest, req))
+
+ _, err := s.ch.SendRequest("signal", false, Marshal(&msg))
+ return err
}
// RFC 4254 Section 6.5.
type execMsg struct {
- PeersId uint32
- Request string
- WantReply bool
- Command string
+ Command string
}
// Start runs cmd on the remote host. Typically, the remote
@@ -261,17 +259,16 @@
return errors.New("ssh: session already started")
}
req := execMsg{
- PeersId: s.remoteId,
- Request: "exec",
- WantReply: true,
- Command: cmd,
+ Command: cmd,
}
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+
+ ok, err := s.ch.SendRequest("exec", true, Marshal(&req))
+ if err == nil && !ok {
+ err = fmt.Errorf("ssh: command %v failed", cmd)
+ }
+ if err != nil {
return err
}
- if err := s.waitForResponse(); err != nil {
- return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err)
- }
return s.start()
}
@@ -339,31 +336,17 @@
if s.started {
return errors.New("ssh: session already started")
}
- req := channelRequestMsg{
- PeersId: s.remoteId,
- Request: "shell",
- WantReply: true,
+
+ ok, err := s.ch.SendRequest("shell", true, nil)
+ // A false reply without a transport error means the peer refused the request.
+ if err == nil && !ok {
+ return errors.New("ssh: could not start shell")
}
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil {
+ if err != nil {
return err
}
- if err := s.waitForResponse(); err != nil {
- return fmt.Errorf("ssh: could not execute shell: %v", err)
- }
return s.start()
}
-func (s *Session) waitForResponse() error {
- msg := <-s.msg
- switch msg.(type) {
- case *channelRequestSuccessMsg:
- return nil
- case *channelRequestFailureMsg:
- return errors.New("ssh: request failed")
- }
- return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg)
-}
-
func (s *Session) start() error {
s.started = true
@@ -394,8 +377,11 @@
if !s.started {
return errors.New("ssh: session not started")
}
- waitErr := s.wait()
+ waitErr := <-s.exitStatus
+ if s.stdinPipeWriter != nil {
+ s.stdinPipeWriter.Close()
+ }
var copyError error
for _ = range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil {
@@ -408,52 +394,35 @@
return copyError
}
-func (s *Session) wait() error {
+func (s *Session) wait(reqs <-chan *Request) error {
wm := Waitmsg{status: -1}
-
// Wait for msg channel to be closed before returning.
- for msg := range s.msg {
- switch msg := msg.(type) {
- case *channelRequestMsg:
- switch msg.Request {
- case "exit-status":
- d := msg.RequestSpecificData
- wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
- case "exit-signal":
- signal, rest, ok := parseString(msg.RequestSpecificData)
- if !ok {
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
- }
- wm.signal = safeString(string(signal))
-
- // skip coreDumped bool
- if len(rest) == 0 {
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
- }
- rest = rest[1:]
-
- errmsg, rest, ok := parseString(rest)
- if !ok {
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
- }
- wm.msg = safeString(string(errmsg))
-
- lang, _, ok := parseString(rest)
- if !ok {
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData)
- }
- wm.lang = safeString(string(lang))
- default:
- // This handles keepalives and matches
- // OpenSSH's behaviour.
- if msg.WantReply {
- s.writePacket(marshal(msgChannelFailure, channelRequestFailureMsg{
- PeersId: s.remoteId,
- }))
- }
+ for msg := range reqs {
+ switch msg.Type {
+ case "exit-status":
+ d := msg.Payload
+ wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
+ case "exit-signal":
+ var sigval struct {
+ Signal string
+ CoreDumped bool
+ Error string
+ Lang string
}
+ if err := Unmarshal(msg.Payload, &sigval); err != nil {
+ return err
+ }
+
+ // Must sanitize strings?
+ wm.signal = sigval.Signal
+ wm.msg = sigval.Error
+ wm.lang = sigval.Lang
default:
- return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg)
+ // This handles keepalives and matches
+ // OpenSSH's behaviour.
+ if msg.WantReply {
+ msg.Reply(false, nil)
+ }
}
}
if wm.status == 0 {
@@ -476,12 +445,20 @@
if s.stdinpipe {
return
}
+ var stdin io.Reader
if s.Stdin == nil {
- s.Stdin = new(bytes.Buffer)
+ stdin = new(bytes.Buffer)
+ } else {
+ r, w := io.Pipe()
+ go func() {
+ _, err := io.Copy(w, s.Stdin)
+ w.CloseWithError(err)
+ }()
+ stdin, s.stdinPipeWriter = r, w
}
s.copyFuncs = append(s.copyFuncs, func() error {
- _, err := io.Copy(s.clientChan.stdin, s.Stdin)
- if err1 := s.clientChan.stdin.Close(); err == nil && err1 != io.EOF {
+ _, err := io.Copy(s.ch, stdin)
+ if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
err = err1
}
return err
@@ -496,7 +473,7 @@
s.Stdout = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
- _, err := io.Copy(s.Stdout, s.clientChan.stdout)
+ _, err := io.Copy(s.Stdout, s.ch)
return err
})
}
@@ -509,11 +486,21 @@
s.Stderr = ioutil.Discard
}
s.copyFuncs = append(s.copyFuncs, func() error {
- _, err := io.Copy(s.Stderr, s.clientChan.stderr)
+ _, err := io.Copy(s.Stderr, s.ch.Stderr())
return err
})
}
+// sessionStdin reroutes Close to CloseWrite.
+// Writes go to the embedded io.Writer; StdinPipe passes the session's
+// channel itself as that writer.
+type sessionStdin struct {
+ io.Writer
+ ch Channel
+}
+
+// Close half-closes the channel (CloseWrite) rather than closing it, so the
+// remote side sees EOF on stdin while the session's output stays readable.
+func (s *sessionStdin) Close() error {
+ return s.ch.CloseWrite()
+}
+
// StdinPipe returns a pipe that will be connected to the
// remote command's standard input when the command starts.
func (s *Session) StdinPipe() (io.WriteCloser, error) {
@@ -524,7 +511,7 @@
return nil, errors.New("ssh: StdinPipe after process started")
}
s.stdinpipe = true
- return s.clientChan.stdin, nil
+ return &sessionStdin{s.ch, s.ch}, nil
}
// StdoutPipe returns a pipe that will be connected to the
@@ -541,7 +528,7 @@
return nil, errors.New("ssh: StdoutPipe after process started")
}
s.stdoutpipe = true
- return s.clientChan.stdout, nil
+ return s.ch, nil
}
// StderrPipe returns a pipe that will be connected to the
@@ -558,28 +545,20 @@
return nil, errors.New("ssh: StderrPipe after process started")
}
s.stderrpipe = true
- return s.clientChan.stderr, nil
+ return s.ch.Stderr(), nil
}
-// NewSession returns a new interactive session on the remote host.
-func (c *ClientConn) NewSession() (*Session, error) {
- ch := c.newChan(c.transport)
- if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{
- ChanType: "session",
- PeersId: ch.localId,
- PeersWindow: channelWindowSize,
- MaxPacketSize: channelMaxPacketSize,
- })); err != nil {
- c.chanList.remove(ch.localId)
- return nil, err
+// newSession returns a new interactive session on the remote host.
+// It wraps the channel ch; reqs carries that channel's incoming requests,
+// which a background goroutine consumes via s.wait. The result of wait is
+// delivered on s.exitStatus, where Wait reads it. The returned error is
+// currently always nil.
+func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
+ s := &Session{
+ ch: ch,
}
- if err := ch.waitForChannelOpenResponse(); err != nil {
- c.chanList.remove(ch.localId)
- return nil, fmt.Errorf("ssh: unable to open session: %v", err)
- }
- return &Session{
- clientChan: ch,
- }, nil
+ // Buffered so the goroutine can deliver its result and exit even if
+ // Wait is never called.
+ s.exitStatus = make(chan error, 1)
+ go func() {
+ s.exitStatus <- s.wait(reqs)
+ }()
+
+ return s, nil
}
// An ExitError reports unsuccessful completion of a remote command.
diff --git a/ssh/session_test.go b/ssh/session_test.go
index 5cff58a..cc26573 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -12,71 +12,60 @@
"io"
"io/ioutil"
"math/rand"
- "net"
"testing"
"code.google.com/p/go.crypto/ssh/terminal"
)
-type serverType func(*serverChan, *testing.T)
+type serverType func(Channel, <-chan *Request, *testing.T)
-// dial constructs a new test server and returns a *ClientConn.
+// dial constructs a new test server that runs handler for every accepted
+// "session" channel, and returns a *Client connected to it through netPipe.
-func dial(handler serverType, t *testing.T) *ClientConn {
- l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
+func dial(handler serverType, t *testing.T) *Client {
+ c1, c2, err := netPipe()
if err != nil {
- t.Fatalf("unable to listen: %v", err)
+ t.Fatalf("netPipe: %v", err)
}
+
go func() {
- defer l.Close()
- conn, err := l.Accept()
+ defer c1.Close()
+ conf := ServerConfig{
+ NoClientAuth: true,
+ }
+ conf.AddHostKey(testSigners["rsa"])
+
+ _, chans, reqs, err := NewServerConn(c1, &conf)
if err != nil {
- t.Errorf("Unable to accept: %v", err)
+ // t.Fatalf must not be called from a non-test goroutine;
+ // report the failure and stop serving instead.
+ t.Errorf("Unable to handshake: %v", err)
return
}
- defer conn.Close()
- if err := conn.Handshake(); err != nil {
- t.Errorf("Unable to handshake: %v", err)
- return
- }
- done := make(chan struct{})
- for {
- ch, err := conn.Accept()
- if err == io.EOF || err == io.ErrUnexpectedEOF {
- return
- }
- // We sometimes get ECONNRESET rather than EOF.
- if _, ok := err.(*net.OpError); ok {
- return
- }
- if err != nil {
- t.Errorf("Unable to accept incoming channel request: %v", err)
- return
- }
- if ch.ChannelType() != "session" {
- ch.Reject(UnknownChannelType, "unknown channel type")
+ go DiscardRequests(reqs)
+
+ for newCh := range chans {
+ if newCh.ChannelType() != "session" {
+ newCh.Reject(UnknownChannelType, "unknown channel type")
continue
}
- ch.Accept()
+
+ ch, inReqs, err := newCh.Accept()
+ if err != nil {
+ t.Errorf("Accept: %v", err)
+ continue
+ }
go func() {
- defer close(done)
- handler(ch.(*serverChan), t)
+ handler(ch, inReqs, t)
}()
}
- <-done
}()
config := &ClientConfig{
User: "testuser",
- Auth: []ClientAuth{
- ClientAuthPassword(clientPassword),
- },
}
- c, err := Dial("tcp", l.Addr().String(), config)
+ conn, chans, reqs, err := NewClientConn(c2, "", config)
if err != nil {
t.Fatalf("unable to dial remote side: %v", err)
}
- return c
+
+ return NewClient(conn, chans, reqs)
}
// Test a simple string is returned to session.Stdout.
@@ -330,164 +319,6 @@
}
}
-func TestInvalidServerMessage(t *testing.T) {
- conn := dial(sendInvalidRecord, t)
- defer conn.Close()
- session, err := conn.NewSession()
- if err != nil {
- t.Fatalf("Unable to request new session: %v", err)
- }
- // Make sure that we closed all the clientChans when the connection
- // failed.
- session.wait()
-
- defer session.Close()
-}
-
-// In the wild some clients (and servers) send zero sized window updates.
-// Test that the client can continue after receiving a zero sized update.
-func TestClientZeroWindowAdjust(t *testing.T) {
- conn := dial(sendZeroWindowAdjust, t)
- defer conn.Close()
- session, err := conn.NewSession()
- if err != nil {
- t.Fatalf("Unable to request new session: %v", err)
- }
- defer session.Close()
-
- if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %v", err)
- }
- err = session.Wait()
- if err != nil {
- t.Fatalf("expected nil but got %v", err)
- }
-}
-
-// In the wild some clients (and servers) send zero sized window updates.
-// Test that the server can continue after receiving a zero size update.
-func TestServerZeroWindowAdjust(t *testing.T) {
- conn := dial(exitStatusZeroHandler, t)
- defer conn.Close()
- session, err := conn.NewSession()
- if err != nil {
- t.Fatalf("Unable to request new session: %v", err)
- }
- defer session.Close()
-
- if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %v", err)
- }
-
- // send a bogus zero sized window update
- session.clientChan.sendWindowAdj(0)
-
- err = session.Wait()
- if err != nil {
- t.Fatalf("expected nil but got %v", err)
- }
-}
-
-// Verify that the client never sends a packet larger than maxpacket.
-func TestClientStdinRespectsMaxPacketSize(t *testing.T) {
- conn := dial(discardHandler, t)
- defer conn.Close()
- session, err := conn.NewSession()
- if err != nil {
- t.Fatalf("failed to request new session: %v", err)
- }
- defer session.Close()
- stdin, err := session.StdinPipe()
- if err != nil {
- t.Fatalf("failed to obtain stdinpipe: %v", err)
- }
- const size = 100 * 1000
- for i := 0; i < 10; i++ {
- n, err := stdin.Write(make([]byte, size))
- if n != size || err != nil {
- t.Fatalf("failed to write: %d, %v", n, err)
- }
- }
-}
-
-// Verify that the client never accepts a packet larger than maxpacket.
-func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
- conn := dial(largeSendHandler, t)
- defer conn.Close()
- session, err := conn.NewSession()
- if err != nil {
- t.Fatalf("Unable to request new session: %v", err)
- }
- defer session.Close()
- out, err := session.StdoutPipe()
- if err != nil {
- t.Fatalf("Unable to connect to Stdout: %v", err)
- }
- if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %v", err)
- }
- if _, err := ioutil.ReadAll(out); err != nil {
- t.Fatalf("failed to read: %v", err)
- }
-}
-
-func TestClientCannotSendAfterEOF(t *testing.T) {
- conn := dial(exitWithoutSignalOrStatus, t)
- defer conn.Close()
- session, err := conn.NewSession()
- if err != nil {
- t.Fatalf("Unable to request new session: %v", err)
- }
- defer session.Close()
- in, err := session.StdinPipe()
- if err != nil {
- t.Fatalf("Unable to connect channel stdin: %v", err)
- }
- if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %v", err)
- }
- if err := in.Close(); err != nil {
- t.Fatalf("Unable to close stdin: %v", err)
- }
- if _, err := in.Write([]byte("foo")); err == nil {
- t.Fatalf("Session write should fail")
- }
-}
-
-func TestClientCannotSendAfterClose(t *testing.T) {
- conn := dial(exitWithoutSignalOrStatus, t)
- defer conn.Close()
- session, err := conn.NewSession()
- if err != nil {
- t.Fatalf("Unable to request new session: %v", err)
- }
- defer session.Close()
- in, err := session.StdinPipe()
- if err != nil {
- t.Fatalf("Unable to connect channel stdin: %v", err)
- }
- if err := session.Shell(); err != nil {
- t.Fatalf("Unable to execute command: %v", err)
- }
- // close underlying channel
- if err := session.channel.Close(); err != nil {
- t.Fatalf("Unable to close session: %v", err)
- }
- if _, err := in.Write([]byte("foo")); err == nil {
- t.Fatalf("Session write should fail")
- }
-}
-
-func TestClientCannotSendHugePacket(t *testing.T) {
- // client and server use the same transport write code so this
- // test suffices for both.
- conn := dial(shellHandler, t)
- defer conn.Close()
- if err := conn.transport.writePacket(make([]byte, maxPacket*2)); err == nil {
- t.Fatalf("huge packet write should fail")
- }
-}
-
// windowTestBytes is the number of bytes that we'll send to the SSH server.
const windowTestBytes = 16000 * 200
@@ -560,93 +391,104 @@
}
type exitStatusMsg struct {
- PeersId uint32
- Request string
- WantReply bool
- Status uint32
+ Status uint32
}
type exitSignalMsg struct {
- PeersId uint32
- Request string
- WantReply bool
Signal string
CoreDumped bool
Errmsg string
Lang string
}
-func newServerShell(ch *serverChan, prompt string) *ServerTerminal {
- term := terminal.NewTerminal(ch, prompt)
- return &ServerTerminal{
- Term: term,
- Channel: ch,
+// handleTerminalRequests replies to every request received on in: it
+// accepts "shell" (only when it carries no payload) and "env", and rejects
+// everything else. It returns once in is closed.
+func handleTerminalRequests(in <-chan *Request) {
+ for req := range in {
+ ok := false
+ switch req.Type {
+ case "shell":
+ ok = true
+ if len(req.Payload) > 0 {
+ // We don't accept any commands, only the default shell.
+ ok = false
+ }
+ case "env":
+ ok = true
+ }
+ req.Reply(ok, nil)
}
}
-func exitStatusZeroHandler(ch *serverChan, t *testing.T) {
+// newServerShell wraps ch in a terminal that uses the given prompt, and
+// answers the channel's requests in the background via handleTerminalRequests.
+func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
+ term := terminal.NewTerminal(ch, prompt)
+ go handleTerminalRequests(in)
+ return term
+}
+
+func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
- shell := newServerShell(ch, "> ")
+ shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(0, ch, t)
}
-func exitStatusNonZeroHandler(ch *serverChan, t *testing.T) {
+func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
- shell := newServerShell(ch, "> ")
+ shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
}
-func exitSignalAndStatusHandler(ch *serverChan, t *testing.T) {
+func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
- shell := newServerShell(ch, "> ")
+ shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendStatus(15, ch, t)
sendSignal("TERM", ch, t)
}
-func exitSignalHandler(ch *serverChan, t *testing.T) {
+func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
- shell := newServerShell(ch, "> ")
+ shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("TERM", ch, t)
}
-func exitSignalUnknownHandler(ch *serverChan, t *testing.T) {
+func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
- shell := newServerShell(ch, "> ")
+ shell := newServerShell(ch, in, "> ")
readLine(shell, t)
sendSignal("SYS", ch, t)
}
-func exitWithoutSignalOrStatus(ch *serverChan, t *testing.T) {
+func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
- shell := newServerShell(ch, "> ")
+ shell := newServerShell(ch, in, "> ")
readLine(shell, t)
}
-func shellHandler(ch *serverChan, t *testing.T) {
+func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
// this string is returned to stdout
- shell := newServerShell(ch, "golang")
+ shell := newServerShell(ch, in, "golang")
readLine(shell, t)
sendStatus(0, ch, t)
}
// Ignores the command, writes fixed strings to stderr and stdout.
// Strings are "this-is-stdout." and "this-is-stderr.".
-func fixedOutputHandler(ch *serverChan, t *testing.T) {
+func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
+ _, err := ch.Read(nil)
- _, err := ch.Read(make([]byte, 0))
- if _, ok := err.(ChannelRequest); !ok {
+ req, ok := <-in
+ if !ok {
t.Fatalf("error: expected channel request, got: %#v", err)
return
}
+
// ignore request, always send some text
- ch.AckRequest(true)
+ req.Reply(true, nil)
_, err = io.WriteString(ch, "this-is-stdout.")
if err != nil {
@@ -659,84 +501,39 @@
sendStatus(0, ch, t)
}
-func readLine(shell *ServerTerminal, t *testing.T) {
+func readLine(shell *terminal.Terminal, t *testing.T) {
if _, err := shell.ReadLine(); err != nil && err != io.EOF {
t.Errorf("unable to read line: %v", err)
}
}
-func sendStatus(status uint32, ch *serverChan, t *testing.T) {
+func sendStatus(status uint32, ch Channel, t *testing.T) {
msg := exitStatusMsg{
- PeersId: ch.remoteId,
- Request: "exit-status",
- WantReply: false,
- Status: status,
+ Status: status,
}
- if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil {
+ if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
t.Errorf("unable to send status: %v", err)
}
}
-func sendSignal(signal string, ch *serverChan, t *testing.T) {
+func sendSignal(signal string, ch Channel, t *testing.T) {
sig := exitSignalMsg{
- PeersId: ch.remoteId,
- Request: "exit-signal",
- WantReply: false,
Signal: signal,
CoreDumped: false,
Errmsg: "Process terminated",
Lang: "en-GB-oed",
}
- if err := ch.writePacket(marshal(msgChannelRequest, sig)); err != nil {
+ if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
t.Errorf("unable to send signal: %v", err)
}
}
-func sendInvalidRecord(ch *serverChan, t *testing.T) {
+func discardHandler(ch Channel, t *testing.T) {
defer ch.Close()
- packet := make([]byte, 1+4+4+1)
- packet[0] = msgChannelData
- marshalUint32(packet[1:], 29348723 /* invalid channel id */)
- marshalUint32(packet[5:], 1)
- packet[9] = 42
-
- if err := ch.writePacket(packet); err != nil {
- t.Errorf("unable send invalid record: %v", err)
- }
-}
-
-func sendZeroWindowAdjust(ch *serverChan, t *testing.T) {
- defer ch.Close()
- // send a bogus zero sized window update
- ch.sendWindowAdj(0)
- shell := newServerShell(ch, "> ")
- readLine(shell, t)
- sendStatus(0, ch, t)
-}
-
-func discardHandler(ch *serverChan, t *testing.T) {
- defer ch.Close()
- // grow the window to avoid being fooled by
- // the initial 1 << 14 window.
- ch.sendWindowAdj(1024 * 1024)
io.Copy(ioutil.Discard, ch)
}
-func largeSendHandler(ch *serverChan, t *testing.T) {
- defer ch.Close()
- // grow the window to avoid being fooled by
- // the initial 1 << 14 window.
- ch.sendWindowAdj(1024 * 1024)
- shell := newServerShell(ch, "> ")
- readLine(shell, t)
- // try to send more than the 32k window
- // will allow
- if err := ch.writePacket(make([]byte, 128*1024)); err == nil {
- t.Errorf("wrote packet larger than 32k")
- }
-}
-
-func echoHandler(ch *serverChan, t *testing.T) {
+func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
@@ -773,17 +570,59 @@
return written, nil
}
-func channelKeepaliveSender(ch *serverChan, t *testing.T) {
+func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
defer ch.Close()
- shell := newServerShell(ch, "> ")
+ shell := newServerShell(ch, in, "> ")
readLine(shell, t)
- msg := channelRequestMsg{
- PeersId: ch.remoteId,
- Request: "keepalive@openssh.com",
- WantReply: true,
- }
- if err := ch.writePacket(marshal(msgChannelRequest, msg)); err != nil {
+ if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
t.Errorf("unable to send channel keepalive request: %v", err)
}
sendStatus(0, ch, t)
}
+
+// TestClientWriteEOF checks that closing the session's stdin pipe delivers
+// EOF to the server while stdout stays readable: the server echoes the data
+// back only after it has read to EOF.
+func TestClientWriteEOF(t *testing.T) {
+ conn := dial(simpleEchoHandler, t)
+ defer conn.Close()
+
+ session, err := conn.NewSession()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer session.Close()
+ stdin, err := session.StdinPipe()
+ if err != nil {
+ t.Fatalf("StdinPipe failed: %v", err)
+ }
+ stdout, err := session.StdoutPipe()
+ if err != nil {
+ t.Fatalf("StdoutPipe failed: %v", err)
+ }
+
+ data := []byte(`0000`)
+ _, err = stdin.Write(data)
+ if err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+ stdin.Close()
+
+ res, err := ioutil.ReadAll(stdout)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ if !bytes.Equal(data, res) {
+ t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
+ }
+}
+
+// simpleEchoHandler reads the channel until EOF, then writes everything it
+// read back to the client before closing the channel.
+func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
+ defer ch.Close()
+ data, err := ioutil.ReadAll(ch)
+ if err != nil {
+ t.Errorf("handler read error: %v", err)
+ }
+ _, err = ch.Write(data)
+ if err != nil {
+ t.Errorf("handler write error: %v", err)
+ }
+}
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 74fc1a7..5a4fa8b 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -16,10 +16,11 @@
"time"
)
-// Listen requests the remote peer open a listening socket
-// on addr. Incoming connections will be available by calling
-// Accept on the returned net.Listener.
-func (c *ClientConn) Listen(n, addr string) (net.Listener, error) {
+// Listen requests the remote peer open a listening socket on
+// addr. Incoming connections will be available by calling Accept on
+// the returned net.Listener. The listener must be serviced, or the
+// SSH connection may hang.
+func (c *Client) Listen(n, addr string) (net.Listener, error) {
laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
@@ -59,7 +60,7 @@
// autoPortListenWorkaround simulates automatic port allocation by
// trying random ports repeatedly.
-func (c *ClientConn) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
+func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
var sshListener net.Listener
var err error
const tries = 10
@@ -77,44 +78,45 @@
// RFC 4254 7.1
+// channelForwardMsg is the type-specific payload of the "tcpip-forward" and
+// "cancel-tcpip-forward" global requests: the address and port to bind. The
+// request name and want-reply flag are passed to SendRequest separately.
type channelForwardMsg struct {
- Message string
- WantReply bool
- raddr string
- rport uint32
+ addr string
+ rport uint32
}
// ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener.
+//
+// When laddr.Port is 0 the server picks the port, and the port it reports is
+// written back into laddr (the caller's TCPAddr is modified). For OpenSSH
+// versions flagged by isBrokenOpenSSHVersion, port 0 is instead emulated by
+// autoPortListenWorkaround.
-func (c *ClientConn) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
- if laddr.Port == 0 && isBrokenOpenSSHVersion(c.serverVersion) {
+func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
+ if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
return c.autoPortListenWorkaround(laddr)
}
m := channelForwardMsg{
- "tcpip-forward",
- true, // sendGlobalRequest waits for a reply
laddr.IP.String(),
uint32(laddr.Port),
}
// send message
- resp, err := c.sendGlobalRequest(m)
+ ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
if err != nil {
return nil, err
}
+ if !ok {
+ return nil, errors.New("ssh: tcpip-forward request denied by peer")
+ }
// If the original port was 0, then the remote side will
// supply a real port number in the response.
if laddr.Port == 0 {
- port, _, ok := parseUint32(resp.Data)
- if !ok {
- return nil, errors.New("unable to parse response")
+ // The reply payload is decoded as a single uint32: the allocated port.
+ var p struct {
+ Port uint32
}
- laddr.Port = int(port)
+ if err := Unmarshal(resp, &p); err != nil {
+ return nil, err
+ }
+ laddr.Port = int(p.Port)
}
// Register this forward, using the port number we obtained.
- ch := c.forwardList.add(*laddr)
+ ch := c.forwards.add(*laddr)
return &tcpListener{laddr, c, ch}, nil
}
@@ -137,7 +139,7 @@
// arguments to add/remove/lookup should be address as specified in
// the original forward-request.
type forward struct {
- c *clientChan // the ssh client channel underlying this forward
+ newCh NewChannel // the incoming forwarded-tcpip channel; Accept is called on it by tcpListener.Accept
raddr *net.TCPAddr // the raddr of the incoming connection
}
@@ -152,6 +154,31 @@
return f.c
}
+// handleChannels routes each incoming channel from in to the tcpListener
+// registered for its connected address. The channel's extra data holds the
+// connected address followed by the originator address (RFC 4254 7.2).
+// Channels whose addresses cannot be parsed, or that match no active forward,
+// are rejected. It returns once in is closed.
+func (l *forwardList) handleChannels(in <-chan NewChannel) {
+	for ch := range in {
+		laddr, rest, ok := parseTCPAddr(ch.ExtraData())
+		var raddr *net.TCPAddr
+		if ok {
+			raddr, rest, ok = parseTCPAddr(rest)
+		}
+
+		switch {
+		case !ok:
+			// invalid request
+			ch.Reject(ConnectionFailed, "could not parse TCP address")
+		case !l.forward(*laddr, *raddr, ch):
+			// Section 7.2, implementations MUST reject spurious incoming
+			// connections.
+			ch.Reject(Prohibited, "no forward for address")
+		}
+	}
+}
+
// remove removes the forward entry, and the channel feeding its
// listener.
func (l *forwardList) remove(addr net.TCPAddr) {
@@ -176,21 +203,22 @@
l.entries = nil
}
+// forward hands ch, together with the originator address raddr, to the
+// listener registered for laddr. It reports false when no forward matches
+// laddr, leaving ch for the caller to reject.
-func (l *forwardList) lookup(addr net.TCPAddr) (chan forward, bool) {
+func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
- if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
- return f.c, true
+ if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
+ // NOTE(review): this send happens while l is locked. Unless f.c is
+ // buffered, it waits for the listener's Accept, stalling
+ // handleChannels and every other forwardList method meanwhile
+ // (compare the warning on Listen that the listener must be serviced).
+ f.c <- forward{ch, &raddr}
+ return true
}
}
- return nil, false
+ return false
}
+// tcpListener is the net.Listener returned by ListenTCP. laddr is the
+// forwarded address on the remote side, conn the client that requested it,
+// and in the channel (obtained from forwards.add) that delivers incoming
+// forwards to Accept.
type tcpListener struct {
laddr *net.TCPAddr
- conn *ClientConn
+ conn *Client
in <-chan forward
}
@@ -200,30 +228,33 @@
if !ok {
return nil, io.EOF
}
+ ch, incoming, err := s.newCh.Accept()
+ if err != nil {
+ return nil, err
+ }
+ go DiscardRequests(incoming)
+
return &tcpChanConn{
- tcpChan: &tcpChan{
- clientChan: s.c,
- Reader: s.c.stdout,
- Writer: s.c.stdin,
- },
- laddr: l.laddr,
- raddr: s.raddr,
+ Channel: ch,
+ laddr: l.laddr,
+ raddr: s.raddr,
}, nil
}
// Close closes the listener.
+//
+// It first removes the local forward entry and then sends
+// "cancel-tcpip-forward" to the server; a negative reply is reported as an
+// error.
func (l *tcpListener) Close() error {
m := channelForwardMsg{
- "cancel-tcpip-forward",
- true,
l.laddr.IP.String(),
uint32(l.laddr.Port),
}
- l.conn.forwardList.remove(*l.laddr)
- if _, err := l.conn.sendGlobalRequest(m); err != nil {
- return err
+
+ // Removing the entry also removes the channel feeding this listener
+ // (presumably closing it, so a pending Accept returns io.EOF).
+ l.conn.forwards.remove(*l.laddr)
+ ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
+ if err == nil && !ok {
+ err = errors.New("ssh: cancel-tcpip-forward failed")
}
- return nil
+ return err
}
// Addr returns the listener's network address.
@@ -233,7 +264,7 @@
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
-func (c *ClientConn) Dial(n, addr string) (net.Conn, error) {
+func (c *Client) Dial(n, addr string) (net.Conn, error) {
// Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
@@ -253,7 +284,7 @@
return nil, err
}
return &tcpChanConn{
- tcpChan: ch,
+ Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
@@ -262,7 +293,7 @@
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
-func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
+func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
if laddr == nil {
laddr = &net.TCPAddr{
IP: net.IPv4zero,
@@ -274,7 +305,7 @@
return nil, err
}
return &tcpChanConn{
- tcpChan: ch,
+ Channel: ch,
laddr: laddr,
raddr: raddr,
}, nil
@@ -282,54 +313,32 @@
// RFC 4254 7.2
+// channelOpenDirectMsg is the type-specific data of a "direct-tcpip" channel
+// open: the host and port to connect to (raddr, rport) followed by the
+// originator's host and port (laddr, lport).
type channelOpenDirectMsg struct {
- ChanType string
- PeersId uint32
- PeersWindow uint32
- MaxPacketSize uint32
- raddr string
- rport uint32
- laddr string
- lport uint32
+ raddr string
+ rport uint32
+ laddr string
+ lport uint32
}
-// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
-// strings and are expected to be resolvable at the remote end.
-func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
- ch := c.newChan(c.transport)
- if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
- ChanType: "direct-tcpip",
- PeersId: ch.localId,
- PeersWindow: channelWindowSize,
- MaxPacketSize: channelMaxPacketSize,
- raddr: raddr,
- rport: uint32(rport),
- laddr: laddr,
- lport: uint32(lport),
- })); err != nil {
- c.chanList.remove(ch.localId)
- return nil, err
+// dial opens a direct-tcpip channel (RFC 4254 section 7.2) asking the remote
+// server to connect to raddr:rport, reporting laddr:lport as the originator.
+// The addresses are passed as strings and are resolved at the remote end.
+func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
+	msg := channelOpenDirectMsg{
+		raddr: raddr,
+		rport: uint32(rport),
+		laddr: laddr,
+		lport: uint32(lport),
}
- if err := ch.waitForChannelOpenResponse(); err != nil {
- c.chanList.remove(ch.localId)
- return nil, fmt.Errorf("ssh: unable to open direct tcpip connection: %v", err)
- }
- return &tcpChan{
- clientChan: ch,
- Reader: ch.stdout,
- Writer: ch.stdin,
- }, nil
+	ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
+	if err != nil {
+		// No channel was opened, so there are no requests to drain. Starting
+		// DiscardRequests on a request channel that is never closed (or is
+		// nil) would leave that goroutine blocked forever.
+		return nil, err
+	}
+	// Consume any channel-specific requests from the server in the
+	// background.
+	go DiscardRequests(in)
+	return ch, nil
}
+// NOTE(review): after this change nothing visible here constructs tcpChan;
+// dial, Dial, DialTCP and tcpListener.Accept all embed a Channel directly in
+// tcpChanConn. Consider deleting tcpChan if no other file refers to it.
type tcpChan struct {
- *clientChan // the backing channel
- io.Reader
- io.Writer
+ Channel // the backing channel
}
// tcpChanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly.
+//
+// NOTE(review): the sentence above predates embedding Channel directly;
+// tcpChanConn no longer wraps a tcpChan. laddr and raddr are the addresses
+// supplied by Dial, DialTCP or tcpListener.Accept.
type tcpChanConn struct {
- *tcpChan
+ Channel
laddr, raddr net.Addr
}
diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go
index 7fa9fc4..f1265cb 100644
--- a/ssh/tcpip_test.go
+++ b/ssh/tcpip_test.go
@@ -1,3 +1,7 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
package ssh
import (
diff --git a/ssh/terminal/util_windows.go b/ssh/terminal/util_windows.go
new file mode 100644
index 0000000..0a454e0
--- /dev/null
+++ b/ssh/terminal/util_windows.go
@@ -0,0 +1,171 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build windows
+
+// Package terminal provides support functions for dealing with terminals, as
+// commonly found on UNIX systems.
+//
+// Putting a terminal into raw mode is the most common requirement:
+//
+// oldState, err := terminal.MakeRaw(0)
+// if err != nil {
+// panic(err)
+// }
+// defer terminal.Restore(0, oldState)
+package terminal
+
+import (
+ "io"
+ "syscall"
+ "unsafe"
+)
+
+// Console input mode flags accepted by GetConsoleMode/SetConsoleMode (the
+// ENABLE_* input constants of the Windows console API).
+const (
+	enableProcessedInput = 0x0001
+	enableLineInput      = 0x0002
+	enableEchoInput      = 0x0004
+	enableWindowInput    = 0x0008
+	enableMouseInput     = 0x0010
+	enableInsertMode     = 0x0020
+	enableQuickEditMode  = 0x0040
+	enableExtendedFlags  = 0x0080
+	enableAutoPosition   = 0x0100
+)
+
+// Console output mode flags. They reuse the bit values of the input flags
+// above and only have meaning on output handles.
+const (
+	enableProcessedOutput = 0x0001
+	enableWrapAtEolOutput = 0x0002
+)
+
+// kernel32 entry points, resolved lazily on first use.
+var (
+	kernel32 = syscall.NewLazyDLL("kernel32.dll")
+
+	procGetConsoleMode             = kernel32.NewProc("GetConsoleMode")
+	procSetConsoleMode             = kernel32.NewProc("SetConsoleMode")
+	procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
+)
+
+// The types below are meant to mirror the Win32 SHORT, WORD, COORD,
+// SMALL_RECT and CONSOLE_SCREEN_BUFFER_INFO definitions. GetSize passes a
+// pointer to a consoleScreenBufferInfo straight to GetConsoleScreenBufferInfo,
+// so field order and sizes must not change.
+type (
+ short int16
+ word uint16
+
+ coord struct {
+ x short
+ y short
+ }
+ smallRect struct {
+ left short
+ top short
+ right short
+ bottom short
+ }
+ consoleScreenBufferInfo struct {
+ size coord
+ cursorPosition coord
+ attributes word
+ window smallRect
+ maximumWindowSize coord
+ }
+)
+
+// State holds a console mode as read by GetConsoleMode, so that Restore can
+// later hand it back to SetConsoleMode.
+type State struct {
+ mode uint32
+}
+
+// IsTerminal reports whether fd is a console handle, i.e. whether
+// GetConsoleMode succeeds on it.
+func IsTerminal(fd int) bool {
+	var mode uint32
+	ok, _, errno := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&mode)), 0)
+	if errno != 0 {
+		return false
+	}
+	return ok != 0
+}
+
+// MakeRaw puts the terminal connected to the given file descriptor into raw
+// mode and returns the previous state of the terminal so that it can be
+// restored.
+func MakeRaw(fd int) (*State, error) {
+	var st uint32
+	_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+	// Derive the raw mode in a separate variable: st must keep the original
+	// mode, because the returned State is what Restore goes back to.
+	raw := st &^ (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
+	_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(raw), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+	return &State{st}, nil
+}
+
+// GetState captures the console mode of fd so that it can later be passed to
+// Restore, for example to reset the terminal after a signal.
+func GetState(fd int) (*State, error) {
+	state := new(State)
+	_, _, errno := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&state.mode)), 0)
+	if errno != 0 {
+		return nil, error(errno)
+	}
+	return state, nil
+}
+
+// Restore restores the terminal connected to the given file descriptor to a
+// previous state.
+func Restore(fd int, state *State) error {
+	_, _, e := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0)
+	if e != 0 {
+		return error(e)
+	}
+	// Returning e directly would wrap syscall.Errno(0) in a non-nil error
+	// value, making every successful call look like a failure.
+	return nil
+}
+
+// GetSize returns the dimensions of the given terminal: the width and height,
+// in character cells, of the visible console window.
+func GetSize(fd int) (width, height int, err error) {
+	var info consoleScreenBufferInfo
+	_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0)
+	if e != 0 {
+		return 0, 0, error(e)
+	}
+	// info.size is the whole screen buffer, which usually has many more rows
+	// (scrollback) than are visible. The window rectangle uses inclusive
+	// coordinates, hence the +1. Convert before subtracting to stay in int.
+	width = int(info.window.right) - int(info.window.left) + 1
+	height = int(info.window.bottom) - int(info.window.top) + 1
+	return width, height, nil
+}
+
+// ReadPassword reads a line of input from a terminal without local echo. This
+// is commonly used for inputting passwords and other sensitive data. The slice
+// returned does not include the line terminator ("\n" or "\r\n").
+func ReadPassword(fd int) ([]byte, error) {
+	var st uint32
+	_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+	old := st
+
+	// Keep line input and processing so a whole line is delivered, but turn
+	// echo off.
+	st &^= (enableEchoInput)
+	st |= (enableProcessedInput | enableLineInput | enableProcessedOutput)
+	_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
+	if e != 0 {
+		return nil, error(e)
+	}
+
+	defer func() {
+		syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
+	}()
+
+	var buf [16]byte
+	var ret []byte
+	for {
+		n, err := syscall.Read(syscall.Handle(fd), buf[:])
+		if err != nil {
+			return nil, err
+		}
+		if n == 0 {
+			if len(ret) == 0 {
+				return nil, io.EOF
+			}
+			break
+		}
+		if buf[n-1] == '\n' {
+			n--
+		}
+		ret = append(ret, buf[:n]...)
+		if n < len(buf) {
+			break
+		}
+	}
+
+	// The console ends lines with "\r\n". Only the '\n' is stripped above,
+	// and the '\r' may have arrived in an earlier chunk, so drop a trailing
+	// '\r' from the assembled result.
+	if len(ret) > 0 && ret[len(ret)-1] == '\r' {
+		ret = ret[:len(ret)-1]
+	}
+	return ret, nil
+}
diff --git a/ssh/test/agent_unix_test.go b/ssh/test/agent_unix_test.go
new file mode 100644
index 0000000..26c88eb
--- /dev/null
+++ b/ssh/test/agent_unix_test.go
@@ -0,0 +1,50 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+package test
+
+import (
+ "bytes"
+ "testing"
+
+ "code.google.com/p/go.crypto/ssh"
+ "code.google.com/p/go.crypto/ssh/agent"
+)
+
+// TestAgentForward forwards a keyring holding the dsa test key over the
+// connection and checks that "ssh-add -L" run on the server lists that key
+// first.
+func TestAgentForward(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	keyring := agent.NewKeyring()
+	if err := keyring.Add(testPrivateKeys["dsa"], nil, ""); err != nil {
+		t.Fatalf("keyring.Add: %v", err)
+	}
+	pub := testPublicKeys["dsa"]
+
+	sess, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("NewSession: %v", err)
+	}
+	defer sess.Close()
+	if err := agent.RequestAgentForwarding(sess); err != nil {
+		t.Fatalf("RequestAgentForwarding: %v", err)
+	}
+
+	if err := agent.ForwardToAgent(conn, keyring); err != nil {
+		t.Fatalf("ForwardToAgent: %v", err)
+	}
+	out, err := sess.CombinedOutput("ssh-add -L")
+	if err != nil {
+		t.Fatalf("running ssh-add: %v, out %s", err, out)
+	}
+	key, _, _, _, err := ssh.ParseAuthorizedKey(out)
+	if err != nil {
+		t.Fatalf("ParseAuthorizedKey(%q): %v", out, err)
+	}
+
+	if !bytes.Equal(key.Marshal(), pub.Marshal()) {
+		t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub))
+	}
+}
diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go
new file mode 100644
index 0000000..d4f7226
--- /dev/null
+++ b/ssh/test/cert_test.go
@@ -0,0 +1,47 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin freebsd linux netbsd openbsd
+
+package test
+
+import (
+ "crypto/rand"
+ "testing"
+
+ "code.google.com/p/go.crypto/ssh"
+)
+
+// TestCertLogin logs in with the dsa key presented through a user certificate
+// signed by the ecdsa key, which sshd_config trusts via TrustedUserCAKeys.
+func TestCertLogin(t *testing.T) {
+	s := newServer(t)
+	defer s.Shutdown()
+
+	// Use a key different from the default.
+	clientKey := testSigners["dsa"]
+	caAuthKey := testSigners["ecdsa"]
+	cert := &ssh.Certificate{
+		Key:             clientKey.PublicKey(),
+		ValidPrincipals: []string{username()},
+		CertType:        ssh.UserCert,
+		ValidBefore:     ssh.CertTimeInfinity,
+	}
+	if err := cert.SignCert(rand.Reader, caAuthKey); err != nil {
+		t.Fatalf("SignCert: %v", err)
+	}
+
+	certSigner, err := ssh.NewCertSigner(cert, clientKey)
+	if err != nil {
+		t.Fatalf("NewCertSigner: %v", err)
+	}
+
+	conf := &ssh.ClientConfig{
+		User: username(),
+	}
+	conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner))
+	client, err := s.TryDial(conf)
+	if err != nil {
+		t.Fatalf("TryDial: %v", err)
+	}
+	client.Close()
+}
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go
index 3a57c10..881a9da 100644
--- a/ssh/test/forward_unix_test.go
+++ b/ssh/test/forward_unix_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux netbsd openbsd plan9
+// +build darwin freebsd linux netbsd openbsd
package test
diff --git a/ssh/test/session_test.go b/ssh/test/session_test.go
index bd7307d..d8d35a5 100644
--- a/ssh/test/session_test.go
+++ b/ssh/test/session_test.go
@@ -11,6 +11,7 @@
import (
"bytes"
"code.google.com/p/go.crypto/ssh"
+ "errors"
"io"
"strings"
"testing"
@@ -38,12 +39,13 @@
defer server.Shutdown()
conf := clientConfig()
- k := conf.HostKeyChecker.(*storedHostKey)
+ hostDB := hostKeyDB()
+ conf.HostKeyCallback = hostDB.Check
// change the keys.
- k.keys[ssh.KeyAlgoRSA][25]++
- k.keys[ssh.KeyAlgoDSA][25]++
- k.keys[ssh.KeyAlgoECDSA256][25]++
+ hostDB.keys[ssh.KeyAlgoRSA][25]++
+ hostDB.keys[ssh.KeyAlgoDSA][25]++
+ hostDB.keys[ssh.KeyAlgoECDSA256][25]++
conn, err := server.TryDial(conf)
if err == nil {
@@ -54,6 +56,53 @@
}
}
+// TestRunCommandStdin checks that Run completes for a command that never
+// reads its stdin, even though the stdin pipe stays open until Run returns.
+func TestRunCommandStdin(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("session failed: %v", err)
+	}
+	defer session.Close()
+
+	pr, pw := io.Pipe()
+	defer pr.Close()
+	defer pw.Close()
+	session.Stdin = pr
+
+	if runErr := session.Run("true"); runErr != nil {
+		t.Fatalf("session failed: %v", runErr)
+	}
+}
+
+// TestRunCommandStdinError checks that an error raised while reading
+// session.Stdin is returned by Run unchanged.
+func TestRunCommandStdinError(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conn := server.Dial(clientConfig())
+	defer conn.Close()
+
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("session failed: %v", err)
+	}
+	defer session.Close()
+
+	pipeErr := errors.New("closing write end of pipe")
+	pr, pw := io.Pipe()
+	defer pr.Close()
+	pw.CloseWithError(pipeErr)
+	session.Stdin = pr
+
+	if got := session.Run("true"); got != pipeErr {
+		t.Fatalf("expected %v, found %v", pipeErr, got)
+	}
+}
+
func TestRunCommandFailed(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
@@ -107,7 +156,7 @@
t.Fatalf("unable to acquire stdout pipe: %s", err)
}
- err = session.Start("dd if=/dev/urandom bs=2048 count=1")
+ err = session.Start("dd if=/dev/urandom bs=2048 count=1024")
if err != nil {
t.Fatalf("unable to execute remote command: %s", err)
}
@@ -118,11 +167,53 @@
t.Fatalf("error reading from remote stdout: %s", err)
}
- if n != 2048 {
+ if n != 2048*1024 {
t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n)
}
}
+// TestKeyChange sets a 1024-byte RekeyThreshold and reads 1 KiB in each of
+// four sessions, so key re-exchanges should occur. Each key exchange is
+// expected to invoke the host key callback, so hostDB.checkCount must reach
+// at least 4.
+func TestKeyChange(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+	conf := clientConfig()
+	hostDB := hostKeyDB()
+	conf.HostKeyCallback = hostDB.Check
+	conf.RekeyThreshold = 1024
+	conn := server.Dial(conf)
+	defer conn.Close()
+
+	for i := 0; i < 4; i++ {
+		session, err := conn.NewSession()
+		if err != nil {
+			t.Fatalf("unable to create new session: %s", err)
+		}
+
+		stdout, err := session.StdoutPipe()
+		if err != nil {
+			t.Fatalf("unable to acquire stdout pipe: %s", err)
+		}
+
+		err = session.Start("dd if=/dev/urandom bs=1024 count=1")
+		if err != nil {
+			t.Fatalf("unable to execute remote command: %s", err)
+		}
+		buf := new(bytes.Buffer)
+		n, err := io.Copy(buf, stdout)
+		if err != nil {
+			t.Fatalf("error reading from remote stdout: %s", err)
+		}
+
+		want := int64(1024)
+		if n != want {
+			t.Fatalf("Expected %d bytes but read only %d from remote command", want, n)
+		}
+		// Release this session's channel before opening the next one.
+		session.Close()
+	}
+
+	if changes := hostDB.checkCount; changes < 4 {
+		t.Errorf("got %d key changes, want at least 4", changes)
+	}
+}
+
func TestInvalidTerminalMode(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
@@ -183,3 +274,44 @@
t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput)
}
}
+
+// TestCiphers connects once per default cipher, putting that cipher first in
+// the client's preference list. The full default list is appended so the
+// handshake still succeeds if sshd lacks the cipher (in which case that
+// cipher is effectively not tested).
+func TestCiphers(t *testing.T) {
+	var config ssh.Config
+	config.SetDefaults()
+	cipherOrder := config.Ciphers
+
+	for _, ciph := range cipherOrder {
+		// Scope each server to one iteration so its deferred Shutdown runs
+		// before the next sshd starts, rather than when the test returns.
+		func() {
+			server := newServer(t)
+			defer server.Shutdown()
+			conf := clientConfig()
+			conf.Ciphers = []string{ciph}
+			// Don't fail if sshd doesn't have the cipher.
+			conf.Ciphers = append(conf.Ciphers, cipherOrder...)
+			conn, err := server.TryDial(conf)
+			if err != nil {
+				t.Fatalf("failed for cipher %q: %v", ciph, err)
+			}
+			conn.Close()
+		}()
+	}
+}
+
+// TestMACs connects once per default MAC, putting that MAC first in the
+// client's preference list. The full default list is appended so the
+// handshake still succeeds if sshd lacks the MAC (in which case that MAC is
+// effectively not tested).
+func TestMACs(t *testing.T) {
+	var config ssh.Config
+	config.SetDefaults()
+	macOrder := config.MACs
+
+	for _, mac := range macOrder {
+		// Scope each server to one iteration so its deferred Shutdown runs
+		// before the next sshd starts, rather than when the test returns.
+		func() {
+			server := newServer(t)
+			defer server.Shutdown()
+			conf := clientConfig()
+			conf.MACs = []string{mac}
+			// Don't fail if sshd doesn't have the MAC.
+			conf.MACs = append(conf.MACs, macOrder...)
+			conn, err := server.TryDial(conf)
+			if err != nil {
+				t.Fatalf("failed for MAC %q: %v", mac, err)
+			}
+			conn.Close()
+		}()
+	}
+}
diff --git a/ssh/test/tcpip_test.go b/ssh/test/tcpip_test.go
index ee06b60..a2eb935 100644
--- a/ssh/test/tcpip_test.go
+++ b/ssh/test/tcpip_test.go
@@ -9,39 +9,38 @@
// direct-tcpip functional tests
import (
+ "io"
"net"
- "net/http"
"testing"
)
-func TestTCPIPHTTP(t *testing.T) {
- // google.com will generate at least one redirect, possibly three
- // depending on your location.
- doTest(t, "http://google.com")
-}
-
-func TestTCPIPHTTPS(t *testing.T) {
- doTest(t, "https://encrypted.google.com/")
-}
-
-func doTest(t *testing.T, url string) {
+// TestDial opens a direct-tcpip connection through sshd to a local TCP
+// listener and checks that the data the listener sends arrives intact.
+func TestDial(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
- conn := server.Dial(clientConfig())
- defer conn.Close()
+	sshConn := server.Dial(clientConfig())
+	defer sshConn.Close()
- tr := &http.Transport{
- Dial: func(n, addr string) (net.Conn, error) {
- return conn.Dial(n, addr)
- },
- }
- client := &http.Client{
- Transport: tr,
- }
- resp, err := client.Get(url)
+	l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
- t.Fatalf("unable to proxy: %s", err)
+		t.Fatalf("Listen: %v", err)
}
- // got a body without error
- t.Log(resp)
+	defer l.Close()
+
+	// The listener answers every connection with that connection's remote
+	// address (as seen by the listener) and then closes it.
+	go func() {
+		for {
+			c, err := l.Accept()
+			if err != nil {
+				break
+			}
+
+			io.WriteString(c, c.RemoteAddr().String())
+			c.Close()
+		}
+	}()
+
+	conn, err := sshConn.Dial("tcp", l.Addr().String())
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+	defer conn.Close()
+
+	// Read until the listener's Close arrives as EOF, then check that what
+	// was forwarded is the host:port string the listener wrote.
+	var got []byte
+	buf := make([]byte, 256)
+	for {
+		n, err := conn.Read(buf)
+		got = append(got, buf[:n]...)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Read: %v", err)
+		}
+	}
+	if _, _, err := net.SplitHostPort(string(got)); err != nil {
+		t.Fatalf("got %q, want a host:port address: %v", got, err)
+	}
}
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 86df3f4..f44c65d 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin freebsd linux netbsd openbsd plan9
+// +build darwin freebsd linux netbsd openbsd
package test
@@ -11,7 +11,6 @@
import (
"bytes"
"fmt"
- "io"
"io/ioutil"
"log"
"net"
@@ -23,13 +22,14 @@
"text/template"
"code.google.com/p/go.crypto/ssh"
+ "code.google.com/p/go.crypto/ssh/testdata"
)
const sshd_config = `
Protocol 2
-HostKey {{.Dir}}/ssh_host_rsa_key
-HostKey {{.Dir}}/ssh_host_dsa_key
-HostKey {{.Dir}}/ssh_host_ecdsa_key
+HostKey {{.Dir}}/id_rsa
+HostKey {{.Dir}}/id_dsa
+HostKey {{.Dir}}/id_ecdsa
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
@@ -41,41 +41,14 @@
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
-AuthorizedKeysFile {{.Dir}}/authorized_keys
+AuthorizedKeysFile {{.Dir}}/id_user.pub
+TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
`
-var (
- configTmpl template.Template
- privateKey ssh.Signer
- hostKeyRSA ssh.Signer
- hostKeyECDSA ssh.Signer
- hostKeyDSA ssh.Signer
-)
-
-func init() {
- template.Must(configTmpl.Parse(sshd_config))
-
- for n, k := range map[string]*ssh.Signer{
- "ssh_host_ecdsa_key": &hostKeyECDSA,
- "ssh_host_rsa_key": &hostKeyRSA,
- "ssh_host_dsa_key": &hostKeyDSA,
- } {
- var err error
- *k, err = ssh.ParsePrivateKey([]byte(keys[n]))
- if err != nil {
- panic(fmt.Sprintf("ParsePrivateKey(%q): %v", n, err))
- }
- }
-
- var err error
- privateKey, err = ssh.ParsePrivateKey([]byte(testClientPrivateKey))
- if err != nil {
- panic(fmt.Sprintf("ParsePrivateKey: %v", err))
- }
-}
+// configTmpl renders sshd_config. Its {{.Dir}} placeholders name the directory
+// holding sshd.pid and the id_* key files (presumably the temporary
+// directory that newServer creates and fills).
+var configTmpl = template.Must(template.New("").Parse(sshd_config))
type server struct {
t *testing.T
@@ -107,36 +80,44 @@
type storedHostKey struct {
// keys map from an algorithm string to binary key data.
keys map[string][]byte
+
+ // checkCount counts the Check calls. TestKeyChange uses it to detect
+ // that key re-exchanges happened.
+ checkCount int
}
+// Add stores key's wire encoding (key.Marshal()) under its algorithm name
+// (key.Type()), replacing any key already stored for that algorithm.
func (k *storedHostKey) Add(key ssh.PublicKey) {
if k.keys == nil {
k.keys = map[string][]byte{}
}
- k.keys[key.PublicKeyAlgo()] = ssh.MarshalPublicKey(key)
+ k.keys[key.Type()] = key.Marshal()
}
-func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
- if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
- return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
+// Check can be used as a ClientConfig.HostKeyCallback. It counts every call
+// in checkCount and accepts key only if its wire encoding equals the key
+// stored by Add for the same algorithm.
+func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
+	k.checkCount++
+	algo := key.Type()
+	got := key.Marshal()
+
+	if k.keys == nil || !bytes.Equal(got, k.keys[algo]) {
+		// Report the encoded bytes: key is now an ssh.PublicKey interface
+		// value, which %q cannot render meaningfully.
+		return fmt.Errorf("host key mismatch. Got %q, want %q", got, k.keys[algo])
}
return nil
}
-func clientConfig() *ssh.ClientConfig {
- keyChecker := storedHostKey{}
- keyChecker.Add(hostKeyECDSA.PublicKey())
- keyChecker.Add(hostKeyRSA.PublicKey())
- keyChecker.Add(hostKeyDSA.PublicKey())
+// hostKeyDB returns a fresh storedHostKey holding the ecdsa, rsa and dsa test
+// public keys, i.e. the keys newServer writes as sshd's id_ecdsa, id_rsa and
+// id_dsa host keys.
+func hostKeyDB() *storedHostKey {
+ keyChecker := &storedHostKey{}
+ keyChecker.Add(testPublicKeys["ecdsa"])
+ keyChecker.Add(testPublicKeys["rsa"])
+ keyChecker.Add(testPublicKeys["dsa"])
+ return keyChecker
+}
- kc := new(keychain)
- kc.keys = append(kc.keys, privateKey)
+// clientConfig returns a config that logs in as username() with the "user"
+// test key and verifies the server against its own fresh hostKeyDB. Tests
+// that need to inspect the host key database build one with hostKeyDB and
+// set HostKeyCallback themselves.
+func clientConfig() *ssh.ClientConfig {
config := &ssh.ClientConfig{
User: username(),
- Auth: []ssh.ClientAuth{
- ssh.ClientAuthKeyring(kc),
+ Auth: []ssh.AuthMethod{
+ ssh.PublicKeys(testSigners["user"]),
},
- HostKeyChecker: &keyChecker,
+ HostKeyCallback: hostKeyDB().Check,
}
return config
}
@@ -171,7 +152,7 @@
return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
}
-func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
+func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
@@ -197,10 +178,14 @@
s.t.Fatalf("s.cmd.Start: %v", err)
}
s.clientConn = c1
- return ssh.Client(c1, config)
+ conn, chans, reqs, err := ssh.NewClientConn(c1, "", config)
+ if err != nil {
+ return nil, err
+ }
+ return ssh.NewClient(conn, chans, reqs), nil
}
-func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
+func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
conn, err := s.TryDial(config)
if err != nil {
s.t.Fail()
@@ -226,6 +211,17 @@
s.cleanup()
}
+// writeFile creates or truncates path with permissions 0600 and writes
+// contents to it. Any failure panics.
+func writeFile(path string, contents []byte) {
+	const flags = os.O_CREATE | os.O_WRONLY | os.O_TRUNC
+	out, err := os.OpenFile(path, flags, 0600)
+	if err == nil {
+		defer out.Close()
+		_, err = out.Write(contents)
+	}
+	if err != nil {
+		panic(err)
+	}
+}
+
// newServer returns a new mock ssh server.
func newServer(t *testing.T) *server {
dir, err := ioutil.TempDir("", "sshtest")
@@ -244,15 +240,10 @@
}
f.Close()
- for k, v := range keys {
- f, err := os.OpenFile(filepath.Join(dir, k), os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
- if err != nil {
- t.Fatal(err)
- }
- if _, err := f.Write([]byte(v)); err != nil {
- t.Fatal(err)
- }
- f.Close()
+ for k, v := range testdata.PEMBytes {
+ filename := "id_" + k
+ writeFile(filepath.Join(dir, filename), v)
+ writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
}
return &server{
@@ -265,32 +256,3 @@
},
}
}
-
-// keychain implements the ClientKeyring interface.
-type keychain struct {
- keys []ssh.Signer
-}
-
-func (k *keychain) Key(i int) (ssh.PublicKey, error) {
- if i < 0 || i >= len(k.keys) {
- return nil, nil
- }
- return k.keys[i].PublicKey(), nil
-}
-
-func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
- return k.keys[i].Sign(rand, data)
-}
-
-func (k *keychain) loadPEM(file string) error {
- buf, err := ioutil.ReadFile(file)
- if err != nil {
- return err
- }
- key, err := ssh.ParsePrivateKey(buf)
- if err != nil {
- return err
- }
- k.keys = append(k.keys, key)
- return nil
-}
diff --git a/ssh/test/testdata_test.go b/ssh/test/testdata_test.go
new file mode 100644
index 0000000..7f50fbe
--- /dev/null
+++ b/ssh/test/testdata_test.go
@@ -0,0 +1,64 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
+// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
+// instances.
+
+package test
+
+import (
+ "crypto/rand"
+ "fmt"
+
+ "code.google.com/p/go.crypto/ssh"
+ "code.google.com/p/go.crypto/ssh/testdata"
+)
+
+var (
+	testPrivateKeys map[string]interface{}
+	testSigners     map[string]ssh.Signer
+	testPublicKeys  map[string]ssh.PublicKey
+)
+
+// init parses each PEM key in testdata.PEMBytes and stores the result in
+// testPrivateKeys, testSigners and testPublicKeys under the same name. It also
+// adds a "cert" entry to testPrivateKeys and testSigners: the ecdsa key
+// presenting a certificate signed by the rsa key. Any failure panics, since
+// the tests in this package depend on these fixtures.
+//
+// NOTE(review): per the IMPLEMENTOR NOTE above, apply the same SignCert error
+// check to the copies of this file in ssh/ and ssh/agent/.
+func init() {
+	var err error
+
+	n := len(testdata.PEMBytes)
+	testPrivateKeys = make(map[string]interface{}, n)
+	testSigners = make(map[string]ssh.Signer, n)
+	testPublicKeys = make(map[string]ssh.PublicKey, n)
+	for t, k := range testdata.PEMBytes {
+		testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k)
+		if err != nil {
+			panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
+		}
+		testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t])
+		if err != nil {
+			panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
+		}
+		testPublicKeys[t] = testSigners[t].PublicKey()
+	}
+
+	// Create a cert and sign it for use in tests.
+	testCert := &ssh.Certificate{
+		Nonce:           []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
+		ValidAfter:      0,                              // unix epoch
+		ValidBefore:     ssh.CertTimeInfinity,           // The end of currently representable time.
+		Reserved:        []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		Key:             testPublicKeys["ecdsa"],
+		SignatureKey:    testPublicKeys["rsa"],
+		Permissions: ssh.Permissions{
+			CriticalOptions: map[string]string{},
+			Extensions:      map[string]string{},
+		},
+	}
+	// Check SignCert's error like every other fixture step; ignoring it
+	// would leave an unsigned certificate behind silently.
+	if err := testCert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
+		panic(fmt.Sprintf("Unable to sign test certificate: %v", err))
+	}
+	testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
+	testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"])
+	if err != nil {
+		panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
+	}
+}
diff --git a/ssh/testdata/doc.go b/ssh/testdata/doc.go
new file mode 100644
index 0000000..4302486
--- /dev/null
+++ b/ssh/testdata/doc.go
@@ -0,0 +1,8 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package testdata contains test data shared between the various subpackages
+// of the code.google.com/p/go.crypto/ssh package. Under no circumstance should
+// this data be used for production code.
+package testdata
diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go
new file mode 100644
index 0000000..5ff1c0e
--- /dev/null
+++ b/ssh/testdata/keys.go
@@ -0,0 +1,45 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package testdata
+
+// PEMBytes maps a key name ("dsa", "ecdsa", "rsa" or "user") to an
+// unencrypted PEM-encoded private key. These keys are public; tests only.
+var PEMBytes = map[string][]byte{
+	"dsa": []byte(`-----BEGIN DSA PRIVATE KEY-----
+MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB
+lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3
+EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD
+nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV
+2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r
+juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr
+FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz
+DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj
+nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY
+Fmsr0W6fHB9nhS4/UXM8
+-----END DSA PRIVATE KEY-----
+`),
+	"ecdsa": []byte(`-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
+AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
+6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
+-----END EC PRIVATE KEY-----
+`),
+	"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
+r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
+tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
+nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
+2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
+y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
+rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
+-----END RSA PRIVATE KEY-----
+`),
+	"user": []byte(`-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
+AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD
+PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w==
+-----END EC PRIVATE KEY-----
+`),
+}
diff --git a/ssh/testdata_test.go b/ssh/testdata_test.go
new file mode 100644
index 0000000..302fdc8
--- /dev/null
+++ b/ssh/testdata_test.go
@@ -0,0 +1,67 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
+// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
+// instances.
+
+package ssh
+
+import (
+	"crypto/rand"
+	"fmt"
+
+	"code.google.com/p/go.crypto/ssh/testdata"
+)
+
+// Fixtures built by init, keyed by testdata.PEMBytes name; testPrivateKeys
+// and testSigners also get a "cert" entry backed by the ecdsa key.
+var (
+	testPrivateKeys map[string]interface{}
+	testSigners     map[string]Signer
+	testPublicKeys  map[string]PublicKey
+)
+
+func init() {
+	var err error
+
+	n := len(testdata.PEMBytes)
+	testPrivateKeys = make(map[string]interface{}, n)
+	testSigners = make(map[string]Signer, n)
+	testPublicKeys = make(map[string]PublicKey, n)
+	for t, k := range testdata.PEMBytes {
+		testPrivateKeys[t], err = ParseRawPrivateKey(k)
+		if err != nil {
+			panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err))
+		}
+		testSigners[t], err = NewSignerFromKey(testPrivateKeys[t])
+		if err != nil {
+			panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err))
+		}
+		testPublicKeys[t] = testSigners[t].PublicKey()
+	}
+
+	// Create a cert and sign it for use in tests.
+	testCert := &Certificate{
+		Nonce:           []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
+		ValidAfter:      0,                              // unix epoch
+		ValidBefore:     CertTimeInfinity,               // The end of currently representable time.
+		Reserved:        []byte{},                       // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		Key:             testPublicKeys["ecdsa"],
+		SignatureKey:    testPublicKeys["rsa"],
+		Permissions: Permissions{
+			CriticalOptions: map[string]string{},
+			Extensions:      map[string]string{},
+		},
+	}
+	if err = testCert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
+		panic(fmt.Sprintf("Unable to sign test certificate: %v", err))
+	}
+	testPrivateKeys["cert"] = testPrivateKeys["ecdsa"]
+	testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"])
+	if err != nil {
+		panic(fmt.Sprintf("Unable to create certificate signer: %v", err))
+	}
+}
diff --git a/ssh/transport.go b/ssh/transport.go
index 46fa262..4f68b04 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -6,26 +6,12 @@
import (
"bufio"
- "crypto/cipher"
- "crypto/subtle"
- "encoding/binary"
"errors"
- "hash"
"io"
- "net"
- "sync"
)
const (
- packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
-
- // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
- // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
- // indicates implementations SHOULD be able to handle larger packet sizes, but then
- // waffles on about reasonable limits.
- //
- // OpenSSH caps their maxPacket at 256kb so we choose to do the same.
- maxPacket = 256 * 1024
+	gcmCipherID = "aes128-gcm@openssh.com" // AES-128-GCM; newPacketCipher builds it via newGCMCipher
)
// packetConn represents a transport that implements packet based
@@ -41,225 +27,128 @@
Close() error
}
-// transport represents the SSH connection to the remote peer.
+// transport is the keyingTransport that implements the SSH packet
+// protocol.
type transport struct {
- reader
- writer
+ reader connectionState
+ writer connectionState
- net.Conn
+ bufReader *bufio.Reader
+ bufWriter *bufio.Writer
+ rand io.Reader
+
+ io.Closer
// Initial H used for the session ID. Once assigned this does
// not change, even during subsequent key exchanges.
sessionID []byte
}
-// reader represents the incoming connection state.
-type reader struct {
- io.Reader
- common
+func (t *transport) getSessionID() []byte {
+ if t.sessionID == nil {
+ panic("session ID not set yet")
+ }
+ s := make([]byte, len(t.sessionID))
+ copy(s, t.sessionID)
+ return s
}
-// writer represents the outgoing connection state.
-type writer struct {
- sync.Mutex // protects writer.Writer from concurrent writes
- *bufio.Writer
- rand io.Reader
- common
+// packetCipher implements one direction of the SSH binary packet protocol,
+// combining encryption and MAC. Use a separate instance per direction.
+type packetCipher interface {
+	// writePacket encrypts packet and writes it to w; rand is the source of
+	// randomness handed to the cipher. The contents of packet may be scrambled.
+	writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
+
+	// readPacket reads a packet from r and decrypts it. The returned
+	// slice may be overwritten by later calls to readPacket, so callers
+	// that keep it must copy it first.
+	readPacket(seqnum uint32, r io.Reader) ([]byte, error)
+}
+
+// connectionState represents one side (read or write) of the
+// connection. This is necessary because each direction has its own
+// keys, and can even have its own algorithms.
+type connectionState struct {
+ packetCipher
+ seqNum uint32
+ dir direction
+ pendingKeyChange chan packetCipher
}
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
- t.writer.cipherAlgo = algs.wCipher
- t.writer.macAlgo = algs.wMAC
- t.writer.compressionAlgo = algs.wCompression
-
- t.reader.cipherAlgo = algs.rCipher
- t.reader.macAlgo = algs.rMAC
- t.reader.compressionAlgo = algs.rCompression
-
if t.sessionID == nil {
t.sessionID = kexResult.H
}
kexResult.SessionID = t.sessionID
- t.reader.pendingKeyChange <- kexResult
- t.writer.pendingKeyChange <- kexResult
+
+	// Each new cipher takes effect when msgNewKeys is read/written.
+	ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
+	if err != nil {
+		return err
+	}
+	t.reader.pendingKeyChange <- ciph
+
+	ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
+	if err != nil {
+		return err
+	}
+	t.writer.pendingKeyChange <- ciph
return nil
}
-// common represents the cipher state needed to process messages in a single
-// direction.
-type common struct {
- seqNum uint32
- mac hash.Hash
- cipher cipher.Stream
-
- cipherAlgo string
- macAlgo string
- compressionAlgo string
-
- dir direction
- pendingKeyChange chan *kexResult
+// Read and decrypt next packet.
+func (t *transport) readPacket() ([]byte, error) {
+ return t.reader.readPacket(t.bufReader)
}
-// Read and decrypt a single packet from the remote peer.
-func (r *reader) readPacket() ([]byte, error) {
- var lengthBytes = make([]byte, 5)
- var macSize uint32
- if _, err := io.ReadFull(r, lengthBytes); err != nil {
- return nil, err
+func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
+ packet, err := s.packetCipher.readPacket(s.seqNum, r)
+ s.seqNum++
+ if err == nil && len(packet) == 0 {
+ err = errors.New("ssh: zero length packet")
}
- r.cipher.XORKeyStream(lengthBytes, lengthBytes)
-
- if r.mac != nil {
- r.mac.Reset()
- seqNumBytes := []byte{
- byte(r.seqNum >> 24),
- byte(r.seqNum >> 16),
- byte(r.seqNum >> 8),
- byte(r.seqNum),
- }
- r.mac.Write(seqNumBytes)
- r.mac.Write(lengthBytes)
- macSize = uint32(r.mac.Size())
- }
-
- length := binary.BigEndian.Uint32(lengthBytes[0:4])
- paddingLength := uint32(lengthBytes[4])
-
- if length <= paddingLength+1 {
- return nil, errors.New("ssh: invalid packet length, packet too small")
- }
-
- if length > maxPacket {
- return nil, errors.New("ssh: invalid packet length, packet too large")
- }
-
- packet := make([]byte, length-1+macSize)
- if _, err := io.ReadFull(r, packet); err != nil {
- return nil, err
- }
- mac := packet[length-1:]
- r.cipher.XORKeyStream(packet, packet[:length-1])
-
- if r.mac != nil {
- r.mac.Write(packet[:length-1])
- if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 {
- return nil, errors.New("ssh: MAC failure")
- }
- }
-
- r.seqNum++
- packet = packet[:length-paddingLength-1]
-
if len(packet) > 0 && packet[0] == msgNewKeys {
select {
- case k := <-r.pendingKeyChange:
- if err := r.setupKeys(r.dir, k); err != nil {
- return nil, err
- }
+ case cipher := <-s.pendingKeyChange:
+ s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message.")
}
}
- return packet, nil
+
+ // The packet may point to an internal buffer, so copy the
+ // packet out here.
+ fresh := make([]byte, len(packet))
+ copy(fresh, packet)
+
+ return fresh, err
}
-// Read and decrypt next packet discarding debug and noop messages.
-func (t *transport) readPacket() ([]byte, error) {
- for {
- packet, err := t.reader.readPacket()
- if err != nil {
- return nil, err
- }
- if len(packet) == 0 {
- return nil, errors.New("ssh: zero length packet")
- }
-
- if packet[0] != msgIgnore && packet[0] != msgDebug {
- return packet, nil
- }
- }
- panic("unreachable")
+func (t *transport) writePacket(packet []byte) error {
+ return t.writer.writePacket(t.bufWriter, t.rand, packet)
}
-// Encrypt and send a packet of data to the remote peer.
-func (w *writer) writePacket(packet []byte) error {
+func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
- if len(packet) > maxPacket {
- return errors.New("ssh: packet too large")
- }
- w.Mutex.Lock()
- defer w.Mutex.Unlock()
-
- paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
- if paddingLength < 4 {
- paddingLength += packetSizeMultiple
- }
-
- length := len(packet) + 1 + paddingLength
- lengthBytes := []byte{
- byte(length >> 24),
- byte(length >> 16),
- byte(length >> 8),
- byte(length),
- byte(paddingLength),
- }
- padding := make([]byte, paddingLength)
- _, err := io.ReadFull(w.rand, padding)
+ err := s.packetCipher.writePacket(s.seqNum, w, rand, packet)
if err != nil {
return err
}
-
- if w.mac != nil {
- w.mac.Reset()
- seqNumBytes := []byte{
- byte(w.seqNum >> 24),
- byte(w.seqNum >> 16),
- byte(w.seqNum >> 8),
- byte(w.seqNum),
- }
- w.mac.Write(seqNumBytes)
- w.mac.Write(lengthBytes)
- w.mac.Write(packet)
- w.mac.Write(padding)
- }
-
- // TODO(dfc) lengthBytes, packet and padding should be
- // subslices of a single buffer
- w.cipher.XORKeyStream(lengthBytes, lengthBytes)
- w.cipher.XORKeyStream(packet, packet)
- w.cipher.XORKeyStream(padding, padding)
-
- if _, err := w.Write(lengthBytes); err != nil {
- return err
- }
- if _, err := w.Write(packet); err != nil {
- return err
- }
- if _, err := w.Write(padding); err != nil {
- return err
- }
-
- if w.mac != nil {
- if _, err := w.Write(w.mac.Sum(nil)); err != nil {
- return err
- }
- }
-
- w.seqNum++
if err = w.Flush(); err != nil {
return err
}
-
+ s.seqNum++
if changeKeys {
select {
- case k := <-w.pendingKeyChange:
- err = w.setupKeys(w.dir, k)
+ case cipher := <-s.pendingKeyChange:
+ s.packetCipher = cipher
default:
panic("ssh: no key material for msgNewKeys")
}
@@ -267,24 +156,20 @@
return err
}
-func newTransport(conn net.Conn, rand io.Reader, isClient bool) *transport {
+func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
t := &transport{
- reader: reader{
- Reader: bufio.NewReader(conn),
- common: common{
- cipher: noneCipher{},
- pendingKeyChange: make(chan *kexResult, 1),
- },
+ bufReader: bufio.NewReader(rwc),
+ bufWriter: bufio.NewWriter(rwc),
+ rand: rand,
+ reader: connectionState{
+ packetCipher: &streamPacketCipher{cipher: noneCipher{}},
+ pendingKeyChange: make(chan packetCipher, 1),
},
- writer: writer{
- Writer: bufio.NewWriter(conn),
- rand: rand,
- common: common{
- cipher: noneCipher{},
- pendingKeyChange: make(chan *kexResult, 1),
- },
+ writer: connectionState{
+ packetCipher: &streamPacketCipher{cipher: noneCipher{}},
+ pendingKeyChange: make(chan packetCipher, 1),
},
- Conn: conn,
+ Closer: rwc,
}
if isClient {
t.reader.dir = serverKeys
@@ -303,48 +188,64 @@
macKeyTag []byte
}
-// TODO(dfc) can this be made a constant ?
var (
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
+// generateKeys derives the IV, cipher key and MAC key for direction d.
+// Buffer sizes come from the cipher and MAC modes selected in algs; each
+// buffer is then filled from kex using its direction-specific tag.
+func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
+	cm := cipherModes[algs.Cipher]
+	mm := macModes[algs.MAC]
+	iv, key, macKey = make([]byte, cm.ivSize), make([]byte, cm.keySize), make([]byte, mm.keySize)
+
+	parts := []struct{ out, tag []byte }{{iv, d.ivTag}, {key, d.keyTag}, {macKey, d.macKeyTag}}
+	for _, p := range parts {
+		generateKeyMaterial(p.out, p.tag, kex)
+	}
+	return iv, key, macKey
+}
+
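-// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
-// described in RFC 4253, section 6.4. direction should either be serverKeys
-// (to setup server->client keys) or clientKeys (for client->server keys).
+// newPacketCipher creates the packetCipher for one direction, with keys derived
+// from kex.K, kex.H and the session ID as described in RFC 4253, section 6.4.
+// d should be serverKeys (server->client) or clientKeys (client->server).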
-func (c *common) setupKeys(d direction, r *kexResult) error {
- cipherMode := cipherModes[c.cipherAlgo]
- macMode := macModes[c.macAlgo]
+func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
+ iv, key, macKey := generateKeys(d, algs, kex)
- iv := make([]byte, cipherMode.ivSize)
- key := make([]byte, cipherMode.keySize)
- macKey := make([]byte, macMode.keySize)
+ if algs.Cipher == gcmCipherID {
+ return newGCMCipher(iv, key, macKey)
+ }
- h := r.Hash.New()
- generateKeyMaterial(iv, d.ivTag, r.K, r.H, r.SessionID, h)
- generateKeyMaterial(key, d.keyTag, r.K, r.H, r.SessionID, h)
- generateKeyMaterial(macKey, d.macKeyTag, r.K, r.H, r.SessionID, h)
-
- c.mac = macMode.new(macKey)
+ c := &streamPacketCipher{
+ mac: macModes[algs.MAC].new(macKey),
+ }
+ c.macResult = make([]byte, c.mac.Size())
var err error
- c.cipher, err = cipherMode.createCipher(key, iv)
- return err
+ c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
+ if err != nil {
+ return nil, err
+ }
+
+ return c, nil
}
// generateKeyMaterial fills out with key material generated from tag, K, H
// and sessionId, as specified in RFC 4253, section 7.2.
-func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
+func generateKeyMaterial(out, tag []byte, r *kexResult) {
var digestsSoFar []byte
+ h := r.Hash.New()
for len(out) > 0 {
h.Reset()
- h.Write(K)
- h.Write(H)
+ h.Write(r.K)
+ h.Write(r.H)
if len(digestsSoFar) == 0 {
h.Write(tag)
- h.Write(sessionId)
+ h.Write(r.SessionID)
} else {
h.Write(digestsSoFar)
}
diff --git a/ssh/transport_test.go b/ssh/transport_test.go
index 3320114..92d83ab 100644
--- a/ssh/transport_test.go
+++ b/ssh/transport_test.go
@@ -6,6 +6,8 @@
import (
"bytes"
+ "crypto/rand"
+ "encoding/binary"
"strings"
"testing"
)
@@ -67,3 +69,41 @@
}
}
}
+
+// closerBuffer is a bytes.Buffer with a no-op Close; tests pass it to newTransport.
+type closerBuffer struct{ bytes.Buffer }
+
+// Close implements io.Closer and always returns nil.
+func (b *closerBuffer) Close() error {
+	return nil
+}
+
+func TestTransportMaxPacketWrite(t *testing.T) {
+	// writePacket must refuse a payload one byte larger than maxPacket.
+	huge := make([]byte, maxPacket+1)
+	tr := newTransport(&closerBuffer{}, rand.Reader, true)
+
+	if err := tr.writePacket(huge); err == nil {
+		t.Errorf("transport accepted write for a huge packet.")
+	}
+}
+
+func TestTransportMaxPacketReader(t *testing.T) {
+	// Feed the reader a header whose length field exceeds maxPacket (the
+	// padding-length byte stays zero), followed by that many payload bytes.
+	huge := make([]byte, maxPacket+128)
+	var header [5]byte
+	binary.BigEndian.PutUint32(header[:4], uint32(len(huge)))
+
+	buf := &closerBuffer{}
+	buf.Write(header[:])
+	buf.Write(huge)
+	tr := newTransport(buf, rand.Reader, true)
+
+	switch _, err := tr.readPacket(); {
+	case err == nil:
+		t.Errorf("transport succeeded reading huge packet.")
+	case !strings.Contains(err.Error(), "large"):
+		t.Errorf("got %q, should mention %q", err.Error(), "large")
+	}
+}