blob: bc4967a4275b0b77165f53dad389be692c5c5e41 [file] [log] [blame]
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin freebsd linux netbsd openbsd
package test
// functional test harness for unix.
import (
"bytes"
"crypto"
"crypto/dsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"os/user"
"path/filepath"
"sync"
"testing"
"text/template"
"time"
"code.google.com/p/go.crypto/ssh"
)
const sshd_config = `
Protocol 2
HostKey {{.Dir}}/ssh_host_rsa_key
HostKey {{.Dir}}/ssh_host_dsa_key
HostKey {{.Dir}}/ssh_host_ecdsa_key
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile {{.Dir}}/authorized_keys
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
`
var (
configTmpl template.Template
rsakey *rsa.PrivateKey
serializedHostKey []byte
)
func init() {
template.Must(configTmpl.Parse(sshd_config))
block, _ := pem.Decode([]byte(testClientPrivateKey))
rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"]))
if block == nil {
panic("pem.Decode ssh_host_rsa_key")
}
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
panic("ParsePKCS1PrivateKey: " + err.Error())
}
serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey)
}
type server struct {
t *testing.T
cleanup func() // executed during Shutdown
configfile string
cmd *exec.Cmd
output bytes.Buffer // holds stderr from sshd process
}
func username() string {
var username string
if user, err := user.Current(); err == nil {
username = user.Username
} else {
// user.Current() currently requires cgo. If an error is
// returned attempt to get the username from the environment.
log.Printf("user.Current: %v; falling back on $USER", err)
username = os.Getenv("USER")
}
if username == "" {
panic("Unable to get username")
}
return username
}
type storedHostKey struct {
// keys map from an algorithm string to binary key data.
keys map[string][]byte
}
func (k *storedHostKey) Add(algo string, public []byte) {
if k.keys == nil {
k.keys = map[string][]byte{}
}
k.keys[algo] = append([]byte(nil), public...)
}
func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
return errors.New("host key mismatch")
}
return nil
}
func clientConfig() *ssh.ClientConfig {
keyChecker := storedHostKey{}
keyChecker.Add("ssh-rsa", serializedHostKey)
kc := new(keychain)
kc.keys = append(kc.keys, rsakey)
config := &ssh.ClientConfig{
User: username(),
Auth: []ssh.ClientAuth{
ssh.ClientAuthKeyring(kc),
},
HostKeyChecker: &keyChecker,
}
return config
}
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
}
s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
r1, w1, err := os.Pipe()
if err != nil {
s.t.Fatal(err)
}
s.cmd.Stdout = w1
r2, w2, err := os.Pipe()
if err != nil {
s.t.Fatal(err)
}
s.cmd.Stdin = r2
s.cmd.Stderr = os.Stderr
if err := s.cmd.Start(); err != nil {
s.t.Fail()
s.Shutdown()
s.t.Fatalf("s.cmd.Start: %v", err)
}
return ssh.Client(&client{wc: w2, r: r1}, config)
}
func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
conn, err := s.TryDial(config)
if err != nil {
s.t.Fail()
s.Shutdown()
s.t.Fatalf("ssh.Client: %v", err)
}
return conn
}
func (s *server) Shutdown() {
if s.cmd != nil && s.cmd.Process != nil {
// don't check for Kill error; if it fails it's most likely
// "os: process already finished", and we don't care about that.
s.cmd.Process.Kill()
s.cmd.Wait()
}
if s.t.Failed() {
// log any output from sshd process
s.t.Logf("sshd: %q", s.output.String())
}
s.cleanup()
}
// client wraps a pair of Reader/WriteClosers to implement the
// net.Conn interface. Importantly, client also mocks the
// ability of net.Conn to support concurrent calls to Read/Write
// and Close. See golang.org/issue/5138 for more details.
type client struct {
wc io.WriteCloser
r io.Reader
sync.Mutex // protects refcount and closing
refcount int
closing bool
}
var errClosing = errors.New("use of closed network connection")
func (c *client) LocalAddr() net.Addr { return nil }
func (c *client) RemoteAddr() net.Addr { return nil }
func (c *client) SetDeadline(time.Time) error { return nil }
func (c *client) SetReadDeadline(time.Time) error { return nil }
func (c *client) SetWriteDeadline(time.Time) error { return nil }
// incref, decref are copied from the net package (see net/fd_unix.go) to
// implement the concurrent Close contract that net.Conn implementations
// from that that package provide.
func (c *client) incRef(closing bool) error {
c.Lock()
defer c.Unlock()
if c.closing {
return errClosing
}
c.refcount++
if closing {
c.closing = true
}
return nil
}
func (c *client) decRef() {
c.Lock()
defer c.Unlock()
c.refcount--
if c.closing && c.refcount == 0 {
c.wc.Close()
}
}
func (c *client) Close() error {
if err := c.incRef(true); err != nil {
return err
}
c.decRef()
return nil
}
func (c *client) Read(b []byte) (int, error) {
if err := c.incRef(false); err != nil {
return 0, err
}
defer c.decRef()
return c.r.Read(b)
}
func (c *client) Write(b []byte) (int, error) {
if err := c.incRef(false); err != nil {
return 0, err
}
defer c.decRef()
return c.wc.Write(b)
}
// newServer returns a new mock ssh server.
func newServer(t *testing.T) *server {
dir, err := ioutil.TempDir("", "sshtest")
if err != nil {
t.Fatal(err)
}
f, err := os.Create(filepath.Join(dir, "sshd_config"))
if err != nil {
t.Fatal(err)
}
err = configTmpl.Execute(f, map[string]string{
"Dir": dir,
})
if err != nil {
t.Fatal(err)
}
f.Close()
for k, v := range keys {
f, err := os.OpenFile(filepath.Join(dir, k), os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
if err != nil {
t.Fatal(err)
}
if _, err := f.Write([]byte(v)); err != nil {
t.Fatal(err)
}
f.Close()
}
return &server{
t: t,
configfile: f.Name(),
cleanup: func() {
if err := os.RemoveAll(dir); err != nil {
t.Error(err)
}
},
}
}
// keychain implements the ClientKeyring interface
type keychain struct {
keys []interface{}
}
func (k *keychain) Key(i int) (interface{}, error) {
if i < 0 || i >= len(k.keys) {
return nil, nil
}
switch key := k.keys[i].(type) {
case *rsa.PrivateKey:
return &key.PublicKey, nil
case *dsa.PrivateKey:
return &key.PublicKey, nil
}
panic("unknown key type")
}
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
hashFunc := crypto.SHA1
h := hashFunc.New()
h.Write(data)
digest := h.Sum(nil)
switch key := k.keys[i].(type) {
case *rsa.PrivateKey:
return rsa.SignPKCS1v15(rand, key, hashFunc, digest)
}
return nil, errors.New("ssh: unknown key type")
}
func (k *keychain) loadPEM(file string) error {
buf, err := ioutil.ReadFile(file)
if err != nil {
return err
}
block, _ := pem.Decode(buf)
if block == nil {
return errors.New("ssh: no key found")
}
r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
}
k.keys = append(k.keys, r)
return nil
}