| // Copyright 2012 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package agent |
| |
| import ( |
| "bytes" |
| "crypto/rand" |
| "errors" |
| "io" |
| "net" |
| "os" |
| "os/exec" |
| "path/filepath" |
| "runtime" |
| "strconv" |
| "strings" |
| "sync" |
| "testing" |
| "time" |
| |
| "golang.org/x/crypto/ssh" |
| ) |
| |
| // startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it. |
| func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, cleanup func()) { |
| if testing.Short() { |
| // ssh-agent is not always available, and the key |
| // types supported vary by platform. |
| t.Skip("skipping test due to -short") |
| } |
| |
| bin, err := exec.LookPath("ssh-agent") |
| if err != nil { |
| t.Skip("could not find ssh-agent") |
| } |
| |
| cmd := exec.Command(bin, "-s") |
| cmd.Env = []string{} // Do not let the user's environment influence ssh-agent behavior. |
| cmd.Stderr = new(bytes.Buffer) |
| out, err := cmd.Output() |
| if err != nil { |
| t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr) |
| } |
| |
| // Output looks like: |
| // |
| // SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; |
| // SSH_AGENT_PID=15542; export SSH_AGENT_PID; |
| // echo Agent pid 15542; |
| |
| fields := bytes.Split(out, []byte(";")) |
| line := bytes.SplitN(fields[0], []byte("="), 2) |
| line[0] = bytes.TrimLeft(line[0], "\n") |
| if string(line[0]) != "SSH_AUTH_SOCK" { |
| t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) |
| } |
| socket = string(line[1]) |
| |
| line = bytes.SplitN(fields[2], []byte("="), 2) |
| line[0] = bytes.TrimLeft(line[0], "\n") |
| if string(line[0]) != "SSH_AGENT_PID" { |
| t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) |
| } |
| pidStr := line[1] |
| pid, err := strconv.Atoi(string(pidStr)) |
| if err != nil { |
| t.Fatalf("Atoi(%q): %v", pidStr, err) |
| } |
| |
| conn, err := net.Dial("unix", string(socket)) |
| if err != nil { |
| t.Fatalf("net.Dial: %v", err) |
| } |
| |
| ac := NewClient(conn) |
| return ac, socket, func() { |
| proc, _ := os.FindProcess(pid) |
| if proc != nil { |
| proc.Kill() |
| } |
| conn.Close() |
| os.RemoveAll(filepath.Dir(socket)) |
| } |
| } |
| |
| func startAgent(t *testing.T, agent Agent) (client ExtendedAgent, cleanup func()) { |
| c1, c2, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| go ServeAgent(agent, c2) |
| |
| return NewClient(c1), func() { |
| c1.Close() |
| c2.Close() |
| } |
| } |
| |
| // startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client. |
| func startKeyringAgent(t *testing.T) (client ExtendedAgent, cleanup func()) { |
| return startAgent(t, NewKeyring()) |
| } |
| |
| func testOpenSSHAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
| agent, _, cleanup := startOpenSSHAgent(t) |
| defer cleanup() |
| |
| testAgentInterface(t, agent, key, cert, lifetimeSecs) |
| } |
| |
| func testKeyringAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
| agent, cleanup := startKeyringAgent(t) |
| defer cleanup() |
| |
| testAgentInterface(t, agent, key, cert, lifetimeSecs) |
| } |
| |
| func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
| signer, err := ssh.NewSignerFromKey(key) |
| if err != nil { |
| t.Fatalf("NewSignerFromKey(%T): %v", key, err) |
| } |
| // The agent should start up empty. |
| if keys, err := agent.List(); err != nil { |
| t.Fatalf("RequestIdentities: %v", err) |
| } else if len(keys) > 0 { |
| t.Fatalf("got %d keys, want 0: %v", len(keys), keys) |
| } |
| |
| // Attempt to insert the key, with certificate if specified. |
| var pubKey ssh.PublicKey |
| if cert != nil { |
| err = agent.Add(AddedKey{ |
| PrivateKey: key, |
| Certificate: cert, |
| Comment: "comment", |
| LifetimeSecs: lifetimeSecs, |
| }) |
| pubKey = cert |
| } else { |
| err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs}) |
| pubKey = signer.PublicKey() |
| } |
| if err != nil { |
| t.Fatalf("insert(%T): %v", key, err) |
| } |
| |
| // Did the key get inserted successfully? |
| if keys, err := agent.List(); err != nil { |
| t.Fatalf("List: %v", err) |
| } else if len(keys) != 1 { |
| t.Fatalf("got %v, want 1 key", keys) |
| } else if keys[0].Comment != "comment" { |
| t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") |
| } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { |
| t.Fatalf("key mismatch") |
| } |
| |
| // Can the agent make a valid signature? |
| data := []byte("hello") |
| sig, err := agent.Sign(pubKey, data) |
| if err != nil { |
| t.Fatalf("Sign(%s): %v", pubKey.Type(), err) |
| } |
| |
| if err := pubKey.Verify(data, sig); err != nil { |
| t.Fatalf("Verify(%s): %v", pubKey.Type(), err) |
| } |
| |
| // For tests on RSA keys, try signing with SHA-256 and SHA-512 flags |
| if pubKey.Type() == "ssh-rsa" { |
| sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) { |
| sig, err = agent.SignWithFlags(pubKey, data, flag) |
| if err != nil { |
| t.Fatalf("SignWithFlags(%s): %v", pubKey.Type(), err) |
| } |
| if sig.Format != expectedSigFormat { |
| t.Fatalf("Signature format didn't match expected value: %s != %s", sig.Format, expectedSigFormat) |
| } |
| if err := pubKey.Verify(data, sig); err != nil { |
| t.Fatalf("Verify(%s): %v", pubKey.Type(), err) |
| } |
| } |
| sshFlagTest(0, ssh.SigAlgoRSA) |
| sshFlagTest(SignatureFlagRsaSha256, ssh.SigAlgoRSASHA2256) |
| sshFlagTest(SignatureFlagRsaSha512, ssh.SigAlgoRSASHA2512) |
| } |
| |
| // If the key has a lifetime, is it removed when it should be? |
| if lifetimeSecs > 0 { |
| time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond) |
| keys, err := agent.List() |
| if err != nil { |
| t.Fatalf("List: %v", err) |
| } |
| if len(keys) > 0 { |
| t.Fatalf("key not expired") |
| } |
| } |
| |
| } |
| |
| func TestMalformedRequests(t *testing.T) { |
| keyringAgent := NewKeyring() |
| listener, err := netListener() |
| if err != nil { |
| t.Fatalf("netListener: %v", err) |
| } |
| defer listener.Close() |
| |
| testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) { |
| var wg sync.WaitGroup |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| c, err := listener.Accept() |
| if err != nil { |
| t.Errorf("listener.Accept: %v", err) |
| return |
| } |
| defer c.Close() |
| |
| err = ServeAgent(keyringAgent, c) |
| if err == nil { |
| t.Error("ServeAgent should have returned an error to malformed input") |
| } else { |
| if (err != io.EOF) != wantServerErr { |
| t.Errorf("ServeAgent returned expected error: %v", err) |
| } |
| } |
| }() |
| |
| c, err := net.Dial("tcp", listener.Addr().String()) |
| if err != nil { |
| t.Fatalf("net.Dial: %v", err) |
| } |
| _, err = c.Write(requestBytes) |
| if err != nil { |
| t.Errorf("Unexpected error writing raw bytes on connection: %v", err) |
| } |
| c.Close() |
| wg.Wait() |
| } |
| |
| var testCases = []struct { |
| name string |
| requestBytes []byte |
| wantServerErr bool |
| }{ |
| {"Empty request", []byte{}, false}, |
| {"Short header", []byte{0x00}, true}, |
| {"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true}, |
| {"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false}, |
| } |
| for _, tc := range testCases { |
| t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) }) |
| } |
| } |
| |
| func TestAgent(t *testing.T) { |
| for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} { |
| testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0) |
| testKeyringAgent(t, testPrivateKeys[keyType], nil, 0) |
| } |
| } |
| |
| func TestCert(t *testing.T) { |
| cert := &ssh.Certificate{ |
| Key: testPublicKeys["rsa"], |
| ValidBefore: ssh.CertTimeInfinity, |
| CertType: ssh.UserCert, |
| } |
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
| |
| testOpenSSHAgent(t, testPrivateKeys["rsa"], cert, 0) |
| testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0) |
| } |
| |
// netListener creates a localhost TCP listener on an ephemeral port. It
// tries IPv4 loopback first and falls back to IPv6 loopback, returning the
// error from the last attempt if neither works.
func netListener() (net.Listener, error) {
	var (
		l   net.Listener
		err error
	)
	for _, addr := range []string{"127.0.0.1:0", "[::1]:0"} {
		if l, err = net.Listen("tcp", addr); err == nil {
			return l, nil
		}
	}
	return nil, err
}
| |
| // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and |
| // therefore is buffered (net.Pipe deadlocks if both sides start with |
| // a write.) |
| func netPipe() (net.Conn, net.Conn, error) { |
| listener, err := netListener() |
| if err != nil { |
| return nil, nil, err |
| } |
| defer listener.Close() |
| c1, err := net.Dial("tcp", listener.Addr().String()) |
| if err != nil { |
| return nil, nil, err |
| } |
| |
| c2, err := listener.Accept() |
| if err != nil { |
| c1.Close() |
| return nil, nil, err |
| } |
| |
| return c1, c2, nil |
| } |
| |
| func TestServerResponseTooLarge(t *testing.T) { |
| a, b, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| done := make(chan struct{}) |
| defer func() { <-done }() |
| |
| defer a.Close() |
| defer b.Close() |
| |
| var response identitiesAnswerAgentMsg |
| response.NumKeys = 1 |
| response.Keys = make([]byte, maxAgentResponseBytes+1) |
| |
| agent := NewClient(a) |
| go func() { |
| defer close(done) |
| n, err := b.Write(ssh.Marshal(response)) |
| if n < 4 { |
| if runtime.GOOS == "plan9" { |
| if e1, ok := err.(*net.OpError); ok { |
| if e2, ok := e1.Err.(*os.PathError); ok { |
| switch e2.Err.Error() { |
| case "Hangup", "i/o on hungup channel": |
| // syscall.Pwrite returns -1 in this case even when some data did get written. |
| return |
| } |
| } |
| } |
| } |
| t.Errorf("At least 4 bytes (the response size) should have been successfully written: %d < 4: %v", n, err) |
| } |
| }() |
| _, err = agent.List() |
| if err == nil { |
| t.Fatal("Did not get error result") |
| } |
| if err.Error() != "agent: client error: response too large" { |
| t.Fatal("Did not get expected error result") |
| } |
| } |
| |
| func TestAuth(t *testing.T) { |
| agent, _, cleanup := startOpenSSHAgent(t) |
| defer cleanup() |
| |
| a, b, err := netPipe() |
| if err != nil { |
| t.Fatalf("netPipe: %v", err) |
| } |
| |
| defer a.Close() |
| defer b.Close() |
| |
| if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { |
| t.Errorf("Add: %v", err) |
| } |
| |
| serverConf := ssh.ServerConfig{} |
| serverConf.AddHostKey(testSigners["rsa"]) |
| serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
| if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
| return nil, nil |
| } |
| |
| return nil, errors.New("pubkey rejected") |
| } |
| |
| go func() { |
| conn, _, _, err := ssh.NewServerConn(a, &serverConf) |
| if err != nil { |
| t.Fatalf("Server: %v", err) |
| } |
| conn.Close() |
| }() |
| |
| conf := ssh.ClientConfig{ |
| HostKeyCallback: ssh.InsecureIgnoreHostKey(), |
| } |
| conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) |
| conn, _, _, err := ssh.NewClientConn(b, "", &conf) |
| if err != nil { |
| t.Fatalf("NewClientConn: %v", err) |
| } |
| conn.Close() |
| } |
| |
| func TestLockOpenSSHAgent(t *testing.T) { |
| agent, _, cleanup := startOpenSSHAgent(t) |
| defer cleanup() |
| testLockAgent(agent, t) |
| } |
| |
| func TestLockKeyringAgent(t *testing.T) { |
| agent, cleanup := startKeyringAgent(t) |
| defer cleanup() |
| testLockAgent(agent, t) |
| } |
| |
| func testLockAgent(agent Agent, t *testing.T) { |
| if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil { |
| t.Errorf("Add: %v", err) |
| } |
| if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil { |
| t.Errorf("Add: %v", err) |
| } |
| if keys, err := agent.List(); err != nil { |
| t.Errorf("List: %v", err) |
| } else if len(keys) != 2 { |
| t.Errorf("Want 2 keys, got %v", keys) |
| } |
| |
| passphrase := []byte("secret") |
| if err := agent.Lock(passphrase); err != nil { |
| t.Errorf("Lock: %v", err) |
| } |
| |
| if keys, err := agent.List(); err != nil { |
| t.Errorf("List: %v", err) |
| } else if len(keys) != 0 { |
| t.Errorf("Want 0 keys, got %v", keys) |
| } |
| |
| signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) |
| if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { |
| t.Fatalf("Sign did not fail") |
| } |
| |
| if err := agent.Remove(signer.PublicKey()); err == nil { |
| t.Fatalf("Remove did not fail") |
| } |
| |
| if err := agent.RemoveAll(); err == nil { |
| t.Fatalf("RemoveAll did not fail") |
| } |
| |
| if err := agent.Unlock(nil); err == nil { |
| t.Errorf("Unlock with wrong passphrase succeeded") |
| } |
| if err := agent.Unlock(passphrase); err != nil { |
| t.Errorf("Unlock: %v", err) |
| } |
| |
| if err := agent.Remove(signer.PublicKey()); err != nil { |
| t.Fatalf("Remove: %v", err) |
| } |
| |
| if keys, err := agent.List(); err != nil { |
| t.Errorf("List: %v", err) |
| } else if len(keys) != 1 { |
| t.Errorf("Want 1 keys, got %v", keys) |
| } |
| } |
| |
| func testOpenSSHAgentLifetime(t *testing.T) { |
| agent, _, cleanup := startOpenSSHAgent(t) |
| defer cleanup() |
| testAgentLifetime(t, agent) |
| } |
| |
| func testKeyringAgentLifetime(t *testing.T) { |
| agent, cleanup := startKeyringAgent(t) |
| defer cleanup() |
| testAgentLifetime(t, agent) |
| } |
| |
| func testAgentLifetime(t *testing.T, agent Agent) { |
| for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { |
| // Add private keys to the agent. |
| err := agent.Add(AddedKey{ |
| PrivateKey: testPrivateKeys[keyType], |
| Comment: "comment", |
| LifetimeSecs: 1, |
| }) |
| if err != nil { |
| t.Fatalf("add: %v", err) |
| } |
| // Add certs to the agent. |
| cert := &ssh.Certificate{ |
| Key: testPublicKeys[keyType], |
| ValidBefore: ssh.CertTimeInfinity, |
| CertType: ssh.UserCert, |
| } |
| cert.SignCert(rand.Reader, testSigners[keyType]) |
| err = agent.Add(AddedKey{ |
| PrivateKey: testPrivateKeys[keyType], |
| Certificate: cert, |
| Comment: "comment", |
| LifetimeSecs: 1, |
| }) |
| if err != nil { |
| t.Fatalf("add: %v", err) |
| } |
| } |
| time.Sleep(1100 * time.Millisecond) |
| if keys, err := agent.List(); err != nil { |
| t.Errorf("List: %v", err) |
| } else if len(keys) != 0 { |
| t.Errorf("Want 0 keys, got %v", len(keys)) |
| } |
| } |
| |
| type keyringExtended struct { |
| *keyring |
| } |
| |
| func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) { |
| if extensionType != "my-extension@example.com" { |
| return []byte{agentExtensionFailure}, nil |
| } |
| return append([]byte{agentSuccess}, contents...), nil |
| } |
| |
| func TestAgentExtensions(t *testing.T) { |
| agent, _, cleanup := startOpenSSHAgent(t) |
| defer cleanup() |
| _, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02}) |
| if err == nil { |
| t.Fatal("should have gotten agent extension failure") |
| } |
| |
| agent, cleanup = startAgent(t, &keyringExtended{}) |
| defer cleanup() |
| result, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02}) |
| if err != nil { |
| t.Fatalf("agent extension failure: %v", err) |
| } |
| if len(result) != 4 || !bytes.Equal(result, []byte{agentSuccess, 0x00, 0x01, 0x02}) { |
| t.Fatalf("agent extension result invalid: %v", result) |
| } |
| |
| _, err = agent.Extension("bad-extension@example.com", []byte{0x00, 0x01, 0x02}) |
| if err == nil { |
| t.Fatal("should have gotten agent extension failure") |
| } |
| } |