exp/ssh: move common code to common.go

R=agl
CC=golang-dev
https://golang.org/cl/5132041
diff --git a/src/pkg/exp/ssh/Makefile b/src/pkg/exp/ssh/Makefile
index e8f33b7..1a100e9 100644
--- a/src/pkg/exp/ssh/Makefile
+++ b/src/pkg/exp/ssh/Makefile
@@ -6,11 +6,11 @@
 
 TARG=exp/ssh
 GOFILES=\
+	channel.go\
 	common.go\
 	messages.go\
-	server.go\
 	transport.go\
-        channel.go\
-        server_shell.go\
+	server.go\
+	server_shell.go\
 
 include ../../../Make.pkg
diff --git a/src/pkg/exp/ssh/common.go b/src/pkg/exp/ssh/common.go
index 698db60..441c5b1 100644
--- a/src/pkg/exp/ssh/common.go
+++ b/src/pkg/exp/ssh/common.go
@@ -5,7 +5,9 @@
 package ssh
 
 import (
+	"big"
 	"strconv"
+	"sync"
 )
 
 // These are string constants in the SSH protocol.
@@ -19,6 +21,32 @@
 	serviceSSH      = "ssh-connection"
 )
 
+var supportedKexAlgos = []string{kexAlgoDH14SHA1}
+var supportedHostKeyAlgos = []string{hostAlgoRSA}
+var supportedCiphers = []string{cipherAES128CTR}
+var supportedMACs = []string{macSHA196}
+var supportedCompressions = []string{compressionNone}
+
+// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
+type dhGroup struct {
+	g, p *big.Int
+}
+
+// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
+// Oakley Group 14 in RFC 3526.
+var dhGroup14 *dhGroup
+
+var dhGroup14Once sync.Once
+
+func initDHGroup14() {
+	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+
+	dhGroup14 = &dhGroup{
+		g: new(big.Int).SetInt64(2),
+		p: p,
+	}
+}
+
 // UnexpectedMessageError results when the SSH message that we received didn't
 // match what we wanted.
 type UnexpectedMessageError struct {
@@ -38,6 +66,11 @@
 	return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType))
 }
 
+type handshakeMagics struct {
+	clientVersion, serverVersion []byte
+	clientKexInit, serverKexInit []byte
+}
+
 func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
 	for _, clientAlgo := range clientAlgos {
 		for _, serverAlgo := range serverAlgos {
diff --git a/src/pkg/exp/ssh/messages.go b/src/pkg/exp/ssh/messages.go
index d375eaf..bc2333e 100644
--- a/src/pkg/exp/ssh/messages.go
+++ b/src/pkg/exp/ssh/messages.go
@@ -59,6 +59,13 @@
 // in this file. The only wrinkle is that a final member of type []byte with a
 // tag of "rest" receives the remainder of a packet when unmarshaling.
 
+// See RFC 4253, section 11.1.
+type disconnectMsg struct {
+	Reason   uint32
+	Message  string
+	Language string
+}
+
 // See RFC 4253, section 7.1.
 type kexInitMsg struct {
 	Cookie                  [16]byte
@@ -137,6 +144,12 @@
 	Language string
 }
 
+// See RFC 4254, section 5.2.
+type channelData struct {
+	PeersId uint32
+	Payload []byte "rest"
+}
+
 type channelRequestMsg struct {
 	PeersId             uint32
 	Request             string
@@ -555,3 +568,60 @@
 }
 
 var bigIntType = reflect.TypeOf((*big.Int)(nil))
+
+// Decode a packet into it's corresponding message.
+func decode(packet []byte) interface{} {
+	var msg interface{}
+	switch packet[0] {
+	case msgDisconnect:
+		msg = new(disconnectMsg)
+	case msgServiceRequest:
+		msg = new(serviceRequestMsg)
+	case msgServiceAccept:
+		msg = new(serviceAcceptMsg)
+	case msgKexInit:
+		msg = new(kexInitMsg)
+	case msgKexDHInit:
+		msg = new(kexDHInitMsg)
+	case msgKexDHReply:
+		msg = new(kexDHReplyMsg)
+	case msgUserAuthRequest:
+		msg = new(userAuthRequestMsg)
+	case msgUserAuthFailure:
+		msg = new(userAuthFailureMsg)
+	case msgUserAuthPubKeyOk:
+		msg = new(userAuthPubKeyOkMsg)
+	case msgGlobalRequest:
+		msg = new(globalRequestMsg)
+	case msgRequestSuccess:
+		msg = new(channelRequestSuccessMsg)
+	case msgRequestFailure:
+		msg = new(channelRequestFailureMsg)
+	case msgChannelOpen:
+		msg = new(channelOpenMsg)
+	case msgChannelOpenConfirm:
+		msg = new(channelOpenConfirmMsg)
+	case msgChannelOpenFailure:
+		msg = new(channelOpenFailureMsg)
+	case msgChannelWindowAdjust:
+		msg = new(windowAdjustMsg)
+	case msgChannelData:
+		msg = new(channelData)
+	case msgChannelEOF:
+		msg = new(channelEOFMsg)
+	case msgChannelClose:
+		msg = new(channelCloseMsg)
+	case msgChannelRequest:
+		msg = new(channelRequestMsg)
+	case msgChannelSuccess:
+		msg = new(channelRequestSuccessMsg)
+	case msgChannelFailure:
+		msg = new(channelRequestFailureMsg)
+	default:
+		return UnexpectedMessageError{0, packet[0]}
+	}
+	if err := unmarshal(msg, packet, packet[0]); err != nil {
+		return err
+	}
+	return msg
+}
diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go
index bc0af13..9e7a255 100644
--- a/src/pkg/exp/ssh/server.go
+++ b/src/pkg/exp/ssh/server.go
@@ -6,7 +6,6 @@
 
 import (
 	"big"
-	"bufio"
 	"bytes"
 	"crypto"
 	"crypto/rand"
@@ -19,12 +18,6 @@
 	"sync"
 )
 
-var supportedKexAlgos = []string{kexAlgoDH14SHA1}
-var supportedHostKeyAlgos = []string{hostAlgoRSA}
-var supportedCiphers = []string{cipherAES128CTR}
-var supportedMACs = []string{macSHA196}
-var supportedCompressions = []string{compressionNone}
-
 // Server represents an SSH server. A Server may have several ServerConnections.
 type Server struct {
 	rsa           *rsa.PrivateKey
@@ -146,31 +139,6 @@
 	cachedPubKeys []cachedPubKey
 }
 
-// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
-type dhGroup struct {
-	g, p *big.Int
-}
-
-// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
-// Oakley Group 14 in RFC 3526.
-var dhGroup14 *dhGroup
-
-var dhGroup14Once sync.Once
-
-func initDHGroup14() {
-	p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
-
-	dhGroup14 = &dhGroup{
-		g: new(big.Int).SetInt64(2),
-		p: p,
-	}
-}
-
-type handshakeMagics struct {
-	clientVersion, serverVersion []byte
-	clientKexInit, serverKexInit []byte
-}
-
 // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
 // returned values are given the same names as in RFC 4253, section 8.
 func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
@@ -292,18 +260,7 @@
 // Handshake performs an SSH transport and client authentication on the given ServerConnection.
 func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
 	var magics handshakeMagics
-	s.transport = &transport{
-		reader: reader{
-			Reader: bufio.NewReader(conn),
-		},
-		writer: writer{
-			Writer: bufio.NewWriter(conn),
-			rand:   rand.Reader,
-		},
-		Close: func() os.Error {
-			return conn.Close()
-		},
-	}
+	s.transport = newTransport(conn)
 
 	if _, err := conn.Write(serverVersion); err != nil {
 		return err
@@ -612,19 +569,14 @@
 			return nil, err
 		}
 
-		switch packet[0] {
-		case msgChannelOpen:
-			var chanOpen channelOpenMsg
-			if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil {
-				return nil, err
-			}
-
+		switch msg := decode(packet).(type) {
+		case *channelOpenMsg:
 			c := new(channel)
-			c.chanType = chanOpen.ChanType
-			c.theirId = chanOpen.PeersId
-			c.theirWindow = chanOpen.PeersWindow
-			c.maxPacketSize = chanOpen.MaxPacketSize
-			c.extraData = chanOpen.TypeSpecificData
+			c.chanType = msg.ChanType
+			c.theirId = msg.PeersId
+			c.theirWindow = msg.PeersWindow
+			c.maxPacketSize = msg.MaxPacketSize
+			c.extraData = msg.TypeSpecificData
 			c.myWindow = defaultWindowSize
 			c.serverConn = s
 			c.cond = sync.NewCond(&c.lock)
@@ -637,74 +589,53 @@
 			s.lock.Unlock()
 			return c, nil
 
-		case msgChannelRequest:
-			var chanRequest channelRequestMsg
-			if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil {
-				return nil, err
-			}
-
+		case *channelRequestMsg:
 			s.lock.Lock()
-			c, ok := s.channels[chanRequest.PeersId]
+			c, ok := s.channels[msg.PeersId]
 			if !ok {
 				continue
 			}
-			c.handlePacket(&chanRequest)
+			c.handlePacket(msg)
 			s.lock.Unlock()
 
-		case msgChannelData:
-			if len(packet) < 5 {
-				return nil, ParseError{msgChannelData}
-			}
-			chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
-
+		case *channelData:
 			s.lock.Lock()
-			c, ok := s.channels[chanId]
+			c, ok := s.channels[msg.PeersId]
 			if !ok {
 				continue
 			}
-			c.handleData(packet[9:])
+			c.handleData(msg.Payload)
 			s.lock.Unlock()
 
-		case msgChannelEOF:
-			var eofMsg channelEOFMsg
-			if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil {
-				return nil, err
-			}
-
+		case *channelEOFMsg:
 			s.lock.Lock()
-			c, ok := s.channels[eofMsg.PeersId]
+			c, ok := s.channels[msg.PeersId]
 			if !ok {
 				continue
 			}
-			c.handlePacket(&eofMsg)
+			c.handlePacket(msg)
 			s.lock.Unlock()
 
-		case msgChannelClose:
-			var closeMsg channelCloseMsg
-			if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil {
-				return nil, err
-			}
-
+		case *channelCloseMsg:
 			s.lock.Lock()
-			c, ok := s.channels[closeMsg.PeersId]
+			c, ok := s.channels[msg.PeersId]
 			if !ok {
 				continue
 			}
-			c.handlePacket(&closeMsg)
+			c.handlePacket(msg)
 			s.lock.Unlock()
 
-		case msgGlobalRequest:
-			var request globalRequestMsg
-			if err := unmarshal(&request, packet, msgGlobalRequest); err != nil {
-				return nil, err
-			}
-
-			if request.WantReply {
+		case *globalRequestMsg:
+			if msg.WantReply {
 				if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
 					return nil, err
 				}
 			}
 
+		case UnexpectedMessageError:
+			return nil, msg
+		case *disconnectMsg:
+			return nil, os.EOF
 		default:
 			// Unknown message. Ignore.
 		}
diff --git a/src/pkg/exp/ssh/transport.go b/src/pkg/exp/ssh/transport.go
index 5a474a9..77660a2 100644
--- a/src/pkg/exp/ssh/transport.go
+++ b/src/pkg/exp/ssh/transport.go
@@ -10,10 +10,13 @@
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/hmac"
+	"crypto/rand"
 	"crypto/subtle"
 	"hash"
 	"io"
+	"net"
 	"os"
+	"sync"
 )
 
 const (
@@ -29,7 +32,8 @@
 	macAlgo         string
 	compressionAlgo string
 
-	Close func() os.Error
+	Close      func() os.Error
+	RemoteAddr func() net.Addr
 }
 
 // reader represents the incoming connection state.
@@ -40,6 +44,7 @@
 
 // writer represnts the outgoing connection state.
 type writer struct {
+	*sync.Mutex // protects writer.Writer from concurrent writes
 	*bufio.Writer
 	paddingMultiple int
 	rand            io.Reader
@@ -126,6 +131,9 @@
 
 // Encrypt and send a packet of data to the remote peer.
 func (w *writer) writePacket(packet []byte) os.Error {
+	w.Mutex.Lock()
+	defer w.Mutex.Unlock()
+
 	paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
 	if paddingLength < 4 {
 		paddingLength += paddingMultiple
@@ -190,6 +198,31 @@
 	return err
 }
 
+// Send a message to the remote peer
+func (t *transport) sendMessage(typ uint8, msg interface{}) os.Error {
+	packet := marshal(typ, msg)
+	return t.writePacket(packet)
+}
+
+func newTransport(conn net.Conn) *transport {
+	return &transport{
+		reader: reader{
+			Reader: bufio.NewReader(conn),
+		},
+		writer: writer{
+			Writer: bufio.NewWriter(conn),
+			rand:   rand.Reader,
+			Mutex:  new(sync.Mutex),
+		},
+		Close: func() os.Error {
+			return conn.Close()
+		},
+		RemoteAddr: func() net.Addr {
+			return conn.RemoteAddr()
+		},
+	}
+}
+
 type direction struct {
 	ivTag     []byte
 	keyTag    []byte