go.crypto/ssh: separate kex algorithms into kexAlgorithm class.
Adds readPacket() to conn, and renames conn to packetConn.
Key exchanges operate on packetConn, so they can be
unittested.
R=agl, jpsugar, dave
CC=golang-dev
https://golang.org/cl/13352055
diff --git a/ssh/channel.go b/ssh/channel.go
index be7b19f..336f7f7 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -78,7 +78,7 @@
)
type channel struct {
- conn // the underlying transport
+ packetConn // the underlying transport
localId, remoteId uint32
remoteWin window
maxPacket uint32
@@ -102,7 +102,7 @@
// sendClose informs the remote side of our intent to close the channel.
func (c *channel) sendClose() error {
- return c.conn.writePacket(marshal(msgChannelClose, channelCloseMsg{
+ return c.packetConn.writePacket(marshal(msgChannelClose, channelCloseMsg{
PeersId: c.remoteId,
}))
}
@@ -124,7 +124,7 @@
if uint32(len(b)) > c.maxPacket {
return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
}
- return c.conn.writePacket(b)
+ return c.packetConn.writePacket(b)
}
func (c *channel) closed() bool {
@@ -447,12 +447,12 @@
// newClientChan returns a partially constructed *clientChan
// using the local id provided. To be usable clientChan.remoteId
// needs to be assigned once known.
-func newClientChan(cc conn, id uint32) *clientChan {
+func newClientChan(cc packetConn, id uint32) *clientChan {
c := &clientChan{
channel: channel{
- conn: cc,
- localId: id,
- remoteWin: window{Cond: newCond()},
+ packetConn: cc,
+ localId: id,
+ remoteWin: window{Cond: newCond()},
},
msg: make(chan interface{}, 16),
}
diff --git a/ssh/client.go b/ssh/client.go
index 6a5ec24..99e8377 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -5,15 +5,11 @@
package ssh
import (
- "crypto"
- "crypto/ecdsa"
- "crypto/elliptic"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
- "math/big"
"net"
"sync"
)
@@ -63,17 +59,14 @@
// handshake performs the client side key exchange. See RFC 4253 Section 7.
func (c *ClientConn) handshake() error {
- var magics handshakeMagics
-
- var version []byte
+ var myVersion []byte
if len(c.config.ClientVersion) > 0 {
- version = []byte(c.config.ClientVersion)
+ myVersion = []byte(c.config.ClientVersion)
} else {
- version = clientVersion
+ myVersion = clientVersion
}
- magics.clientVersion = version
- version = append(version, '\r', '\n')
- if _, err := c.Write(version); err != nil {
+
+ if _, err := c.Write(append(myVersion, '\r', '\n')); err != nil {
return err
}
if err := c.Flush(); err != nil {
@@ -81,12 +74,12 @@
}
// read remote server version
- version, err := readVersion(c)
+ serverVersion, err := readVersion(c)
if err != nil {
return err
}
- magics.serverVersion = version
- c.serverVersion = string(version)
+ c.serverVersion = string(serverVersion)
+
clientKexInit := kexInitMsg{
KexAlgos: c.config.Crypto.kexes(),
ServerHostKeyAlgos: supportedHostKeyAlgos,
@@ -98,8 +91,6 @@
CompressionServerClient: supportedCompressions,
}
kexInitPacket := marshal(msgKexInit, clientKexInit)
- magics.clientKexInit = kexInitPacket
-
if err := c.writePacket(kexInitPacket); err != nil {
return err
}
@@ -108,8 +99,6 @@
return err
}
- magics.serverKexInit = packet
-
var serverKexInit kexInitMsg
if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
return err
@@ -128,23 +117,18 @@
}
}
- var result *kexResult
- switch kexAlgo {
- case kexAlgoECDH256:
- result, err = c.kexECDH(elliptic.P256(), &magics, hostKeyAlgo)
- case kexAlgoECDH384:
- result, err = c.kexECDH(elliptic.P384(), &magics, hostKeyAlgo)
- case kexAlgoECDH521:
- result, err = c.kexECDH(elliptic.P521(), &magics, hostKeyAlgo)
- case kexAlgoDH14SHA1:
- dhGroup14Once.Do(initDHGroup14)
- result, err = c.kexDH(crypto.SHA1, dhGroup14, &magics, hostKeyAlgo)
- case kexAlgoDH1SHA1:
- dhGroup1Once.Do(initDHGroup1)
- result, err = c.kexDH(crypto.SHA1, dhGroup1, &magics, hostKeyAlgo)
- default:
- err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+ kex, ok := kexAlgoMap[kexAlgo]
+ if !ok {
+ return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
}
+
+ magics := handshakeMagics{
+ clientVersion: myVersion,
+ serverVersion: serverVersion,
+ clientKexInit: kexInitPacket,
+ serverKexInit: packet,
+ }
+ result, err := kex.Client(c, c.config.rand(), &magics)
if err != nil {
return err
}
@@ -164,7 +148,8 @@
if err = c.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
- if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, result.Hash); err != nil {
+
+ if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
return err
}
if packet, err = c.readPacket(); err != nil {
@@ -173,72 +158,12 @@
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
- if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, result.Hash); err != nil {
+ if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
return err
}
return c.authenticate(result.H)
}
-// kexECDH performs Elliptic Curve Diffie-Hellman key exchange as
-// described in RFC 5656, section 4.
-func (c *ClientConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) {
- ephKey, err := ecdsa.GenerateKey(curve, c.config.rand())
- if err != nil {
- return nil, err
- }
-
- kexInit := kexECDHInitMsg{
- ClientPubKey: elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
- }
-
- serialized := marshal(msgKexECDHInit, kexInit)
- if err := c.writePacket(serialized); err != nil {
- return nil, err
- }
-
- packet, err := c.readPacket()
- if err != nil {
- return nil, err
- }
-
- var reply kexECDHReplyMsg
- if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil {
- return nil, err
- }
-
- x, y := elliptic.Unmarshal(curve, reply.EphemeralPubKey)
- if x == nil {
- return nil, errors.New("ssh: elliptic.Unmarshal failure")
- }
- if !validateECPublicKey(curve, x, y) {
- return nil, errors.New("ssh: ephemeral server key not on curve")
- }
-
- // generate shared secret
- secret, _ := curve.ScalarMult(x, y, ephKey.D.Bytes())
-
- hashFunc := ecHash(curve)
- h := hashFunc.New()
- writeString(h, magics.clientVersion)
- writeString(h, magics.serverVersion)
- writeString(h, magics.clientKexInit)
- writeString(h, magics.serverKexInit)
- writeString(h, reply.HostKey)
- writeString(h, kexInit.ClientPubKey)
- writeString(h, reply.EphemeralPubKey)
- K := make([]byte, intLength(secret))
- marshalInt(K, secret)
- h.Write(K)
-
- return &kexResult{
- H: h.Sum(nil),
- K: K,
- HostKey: reply.HostKey,
- Signature: reply.Signature,
- Hash: hashFunc,
- }, nil
-}
-
// Verify the host key obtained in the key exchange.
func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error {
hostKey, rest, ok := ParsePublicKey(hostKeyBytes)
@@ -260,74 +185,6 @@
return nil
}
-// kexResult captures the outcome of a key exchange.
-type kexResult struct {
- // Session hash. See also RFC 4253, section 8.
- H []byte
-
- // Shared secret. See also RFC 4253, section 8.
- K []byte
-
- // Host key as hashed into H
- HostKey []byte
-
- // Signature of H
- Signature []byte
-
- // Hash function that was used.
- Hash crypto.Hash
-}
-
-// kexDH performs Diffie-Hellman key agreement on a ClientConn.
-func (c *ClientConn) kexDH(hashFunc crypto.Hash, group *dhGroup, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) {
- x, err := rand.Int(c.config.rand(), group.p)
- if err != nil {
- return nil, err
- }
- X := new(big.Int).Exp(group.g, x, group.p)
- kexDHInit := kexDHInitMsg{
- X: X,
- }
- if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
- return nil, err
- }
-
- packet, err := c.readPacket()
- if err != nil {
- return nil, err
- }
-
- var kexDHReply kexDHReplyMsg
- if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
- return nil, err
- }
-
- kInt, err := group.diffieHellman(kexDHReply.Y, x)
- if err != nil {
- return nil, err
- }
-
- h := hashFunc.New()
- writeString(h, magics.clientVersion)
- writeString(h, magics.serverVersion)
- writeString(h, magics.clientKexInit)
- writeString(h, magics.serverKexInit)
- writeString(h, kexDHReply.HostKey)
- writeInt(h, X)
- writeInt(h, kexDHReply.Y)
- K := make([]byte, intLength(kInt))
- marshalInt(K, kInt)
- h.Write(K)
-
- return &kexResult{
- H: h.Sum(nil),
- K: K,
- HostKey: kexDHReply.HostKey,
- Signature: kexDHReply.Signature,
- Hash: hashFunc,
- }, nil
-}
-
// mainLoop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
@@ -633,18 +490,18 @@
}
// Allocate a new ClientChan with the next avail local id.
-func (c *chanList) newChan(t *transport) *clientChan {
+func (c *chanList) newChan(p packetConn) *clientChan {
c.Lock()
defer c.Unlock()
for i := range c.chans {
if c.chans[i] == nil {
- ch := newClientChan(t, uint32(i))
+ ch := newClientChan(p, uint32(i))
c.chans[i] = ch
return ch
}
}
i := len(c.chans)
- ch := newClientChan(t, uint32(i))
+ ch := newClientChan(p, uint32(i))
c.chans = append(c.chans, ch)
return ch
}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 31a5e09..d2ed48e 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -81,7 +81,7 @@
// Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative
// method names is returned.
- auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
+ auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
// method returns the RFC 4252 method name.
method() string
@@ -90,8 +90,8 @@
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
-func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
- if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
+func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+ if err := c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: "none",
@@ -99,7 +99,7 @@
return false, nil, err
}
- return handleAuthResponse(t)
+ return handleAuthResponse(c)
}
func (n *noneAuth) method() string {
@@ -111,7 +111,7 @@
ClientPassword
}
-func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
+func (p *passwordAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct {
User string
Service string
@@ -125,7 +125,7 @@
return false, nil, err
}
- if err := t.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{
+ if err := c.writePacket(marshal(msgUserAuthRequest, passwordAuthMsg{
User: user,
Service: serviceSSH,
Method: "password",
@@ -135,7 +135,7 @@
return false, nil, err
}
- return handleAuthResponse(t)
+ return handleAuthResponse(c)
}
func (p *passwordAuth) method() string {
@@ -181,7 +181,7 @@
Sig []byte `ssh:"rest"`
}
-func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
+func (p *publickeyAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
@@ -200,7 +200,7 @@
break
}
- if ok, err := p.validateKey(key, user, t); ok {
+ if ok, err := p.validateKey(key, user, c); ok {
validKeys[index] = key
} else {
if err != nil {
@@ -237,10 +237,10 @@
Sig: sig,
}
p := marshal(msgUserAuthRequest, msg)
- if err := t.writePacket(p); err != nil {
+ if err := c.writePacket(p); err != nil {
return false, nil, err
}
- success, methods, err := handleAuthResponse(t)
+ success, methods, err := handleAuthResponse(c)
if err != nil {
return false, nil, err
}
@@ -252,7 +252,7 @@
}
// validateKey validates the key provided it is acceptable to the server.
-func (p *publickeyAuth) validateKey(key PublicKey, user string, t *transport) (bool, error) {
+func (p *publickeyAuth) validateKey(key PublicKey, user string, c packetConn) (bool, error) {
pubkey := MarshalPublicKey(key)
algoname := key.PublicKeyAlgo()
msg := publickeyAuthMsg{
@@ -263,19 +263,19 @@
Algoname: algoname,
Pubkey: string(pubkey),
}
- if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
+ if err := c.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
return false, err
}
- return p.confirmKeyAck(key, t)
+ return p.confirmKeyAck(key, c)
}
-func (p *publickeyAuth) confirmKeyAck(key PublicKey, t *transport) (bool, error) {
+func (p *publickeyAuth) confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
pubkey := MarshalPublicKey(key)
algoname := key.PublicKeyAlgo()
for {
- packet, err := t.readPacket()
+ packet, err := c.readPacket()
if err != nil {
return false, err
}
@@ -312,9 +312,9 @@
// handleAuthResponse returns whether the preceding authentication request succeeded
// along with a list of remaining authentication methods to try next and
// an error if an unexpected response was received.
-func handleAuthResponse(t *transport) (bool, []string, error) {
+func handleAuthResponse(c packetConn) (bool, []string, error) {
for {
- packet, err := t.readPacket()
+ packet, err := c.readPacket()
if err != nil {
return false, nil, err
}
@@ -411,11 +411,11 @@
ClientKeyboardInteractive
}
-func (c *keyboardInteractiveAuth) method() string {
+func (k *keyboardInteractiveAuth) method() string {
return "keyboard-interactive"
}
-func (c *keyboardInteractiveAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
+func (k *keyboardInteractiveAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
type initiateMsg struct {
User string
Service string
@@ -424,7 +424,7 @@
Submethods string
}
- if err := t.writePacket(marshal(msgUserAuthRequest, initiateMsg{
+ if err := c.writePacket(marshal(msgUserAuthRequest, initiateMsg{
User: user,
Service: serviceSSH,
Method: "keyboard-interactive",
@@ -433,7 +433,7 @@
}
for {
- packet, err := t.readPacket()
+ packet, err := c.readPacket()
if err != nil {
return false, nil, err
}
@@ -480,7 +480,7 @@
return false, nil, fmt.Errorf("ssh: junk following message %q", rest)
}
- answers, err := c.Challenge(msg.User, msg.Instruction, prompts, echos)
+ answers, err := k.Challenge(msg.User, msg.Instruction, prompts, echos)
if err != nil {
return false, nil, err
}
@@ -501,7 +501,7 @@
p = marshalString(p, []byte(a))
}
- if err := t.writePacket(serialized); err != nil {
+ if err := c.writePacket(serialized); err != nil {
return false, nil, err
}
}
diff --git a/ssh/common.go b/ssh/common.go
index 7e6e5dc..e640000 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -6,9 +6,7 @@
import (
"crypto"
- "errors"
"fmt"
- "math/big"
"sync"
_ "crypto/sha1"
@@ -18,11 +16,6 @@
// These are string constants in the SSH protocol.
const (
- kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
- kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
- kexAlgoECDH256 = "ecdh-sha2-nistp256"
- kexAlgoECDH384 = "ecdh-sha2-nistp384"
- kexAlgoECDH521 = "ecdh-sha2-nistp521"
hostAlgoRSA = "ssh-rsa"
hostAlgoDSA = "ssh-dss"
compressionNone = "none"
@@ -53,48 +46,6 @@
CertAlgoECDSA521v01: crypto.SHA512,
}
-// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
-type dhGroup struct {
- g, p *big.Int
-}
-
-func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
- if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
- return nil, errors.New("ssh: DH parameter out of bounds")
- }
- return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
-}
-
-// dhGroup1 is the group called diffie-hellman-group1-sha1 in RFC 4253 and
-// Oakley Group 2 in RFC 2409.
-var dhGroup1 *dhGroup
-
-var dhGroup1Once sync.Once
-
-func initDHGroup1() {
- p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
-
- dhGroup1 = &dhGroup{
- g: new(big.Int).SetInt64(2),
- p: p,
- }
-}
-
-// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
-// Oakley Group 14 in RFC 3526.
-var dhGroup14 *dhGroup
-
-var dhGroup14Once sync.Once
-
-func initDHGroup14() {
- p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
-
- dhGroup14 = &dhGroup{
- g: new(big.Int).SetInt64(2),
- p: p,
- }
-}
-
// UnexpectedMessageError results when the SSH message that we received didn't
// match what we wanted.
type UnexpectedMessageError struct {
@@ -114,11 +65,6 @@
return fmt.Sprintf("ssh: parse error in message type %d", p.msgType)
}
-type handshakeMagics struct {
- clientVersion, serverVersion []byte
- clientKexInit, serverKexInit []byte
-}
-
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
for _, clientAlgo := range clientAlgos {
for _, serverAlgo := range serverAlgos {
diff --git a/ssh/kex.go b/ssh/kex.go
new file mode 100644
index 0000000..e6ce6c3
--- /dev/null
+++ b/ssh/kex.go
@@ -0,0 +1,387 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "errors"
+ "io"
+ "math/big"
+)
+
+const (
+ kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
+ kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
+ kexAlgoECDH256 = "ecdh-sha2-nistp256"
+ kexAlgoECDH384 = "ecdh-sha2-nistp384"
+ kexAlgoECDH521 = "ecdh-sha2-nistp521"
+)
+
+// kexResult captures the outcome of a key exchange.
+type kexResult struct {
+ // Session hash. See also RFC 4253, section 8.
+ H []byte
+
+ // Shared secret. See also RFC 4253, section 8.
+ K []byte
+
+ // Host key as hashed into H
+ HostKey []byte
+
+ // Signature of H
+ Signature []byte
+}
+
+// handshakeMagics contains data that is always included in the
+// session hash.
+type handshakeMagics struct {
+ clientVersion, serverVersion []byte
+ clientKexInit, serverKexInit []byte
+}
+
+func (m *handshakeMagics) write(w io.Writer) {
+ writeString(w, m.clientVersion)
+ writeString(w, m.serverVersion)
+ writeString(w, m.clientKexInit)
+ writeString(w, m.serverKexInit)
+}
+
+// kexAlgorithm abstracts different key exchange algorithms.
+type kexAlgorithm interface {
+ // Server runs server-side key agreement, signing the result
+ // with a hostkey.
+ Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error)
+
+ // Client runs the client-side key agreement. Caller is
+ // responsible for verifying the host key signature.
+ Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
+
+ // Hash returns a cryptographic hash function that matches the
+ // security level of the key exchange algorithm. It is used
+ // for calculating kexResult.H, and for deriving keys from
+ // data in kexResult.
+ Hash() crypto.Hash
+}
+
+// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
+type dhGroup struct {
+ g, p *big.Int
+}
+
+func (group *dhGroup) Hash() crypto.Hash {
+ return crypto.SHA1
+}
+
+func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
+ if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
+ return nil, errors.New("ssh: DH parameter out of bounds")
+ }
+ return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
+}
+
+func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
+ hashFunc := crypto.SHA1
+
+ x, err := rand.Int(randSource, group.p)
+ if err != nil {
+ return nil, err
+ }
+ X := new(big.Int).Exp(group.g, x, group.p)
+ kexDHInit := kexDHInitMsg{
+ X: X,
+ }
+ if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
+ return nil, err
+ }
+
+ packet, err := c.readPacket()
+ if err != nil {
+ return nil, err
+ }
+
+ var kexDHReply kexDHReplyMsg
+ if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
+ return nil, err
+ }
+
+ kInt, err := group.diffieHellman(kexDHReply.Y, x)
+ if err != nil {
+ return nil, err
+ }
+
+ h := hashFunc.New()
+ magics.write(h)
+ writeString(h, kexDHReply.HostKey)
+ writeInt(h, X)
+ writeInt(h, kexDHReply.Y)
+ K := make([]byte, intLength(kInt))
+ marshalInt(K, kInt)
+ h.Write(K)
+
+ return &kexResult{
+ H: h.Sum(nil),
+ K: K,
+ HostKey: kexDHReply.HostKey,
+ Signature: kexDHReply.Signature,
+ }, nil
+}
+
+func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+ hashFunc := crypto.SHA1
+ packet, err := c.readPacket()
+ if err != nil {
+ return
+ }
+ var kexDHInit kexDHInitMsg
+ if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
+ return
+ }
+
+ y, err := rand.Int(randSource, group.p)
+ if err != nil {
+ return
+ }
+
+ Y := new(big.Int).Exp(group.g, y, group.p)
+ kInt, err := group.diffieHellman(kexDHInit.X, y)
+ if err != nil {
+ return nil, err
+ }
+
+ hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+
+ h := hashFunc.New()
+ magics.write(h)
+ writeString(h, hostKeyBytes)
+ writeInt(h, kexDHInit.X)
+ writeInt(h, Y)
+
+ K := make([]byte, intLength(kInt))
+ marshalInt(K, kInt)
+ h.Write(K)
+
+ H := h.Sum(nil)
+
+ // H is already a hash, but the hostkey signing will apply its
+ // own key-specific hash algorithm.
+ sig, err := signAndMarshal(priv, randSource, H)
+ if err != nil {
+ return nil, err
+ }
+
+ kexDHReply := kexDHReplyMsg{
+ HostKey: hostKeyBytes,
+ Y: Y,
+ Signature: sig,
+ }
+ packet = marshal(msgKexDHReply, kexDHReply)
+
+ err = c.writePacket(packet)
+ return &kexResult{
+ H: H,
+ K: K,
+ HostKey: hostKeyBytes,
+ Signature: sig,
+ }, nil
+}
+
+// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
+// described in RFC 5656, section 4.
+type ecdh struct {
+ curve elliptic.Curve
+}
+
+func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
+ ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
+ if err != nil {
+ return nil, err
+ }
+
+ kexInit := kexECDHInitMsg{
+ ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
+ }
+
+ serialized := marshal(msgKexECDHInit, kexInit)
+ if err := c.writePacket(serialized); err != nil {
+ return nil, err
+ }
+
+ packet, err := c.readPacket()
+ if err != nil {
+ return nil, err
+ }
+
+ var reply kexECDHReplyMsg
+ if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil {
+ return nil, err
+ }
+
+ x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey)
+ if err != nil {
+ return nil, err
+ }
+
+ // generate shared secret
+ secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes())
+
+ h := ecHash(kex.curve).New()
+ magics.write(h)
+ writeString(h, reply.HostKey)
+ writeString(h, kexInit.ClientPubKey)
+ writeString(h, reply.EphemeralPubKey)
+ K := make([]byte, intLength(secret))
+ marshalInt(K, secret)
+ h.Write(K)
+
+ return &kexResult{
+ H: h.Sum(nil),
+ K: K,
+ HostKey: reply.HostKey,
+ Signature: reply.Signature,
+ }, nil
+}
+
+// unmarshalECKey parses and checks an EC key.
+func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) {
+ x, y = elliptic.Unmarshal(curve, pubkey)
+ if x == nil {
+ return nil, nil, errors.New("ssh: elliptic.Unmarshal failure")
+ }
+ if !validateECPublicKey(curve, x, y) {
+ return nil, nil, errors.New("ssh: public key not on curve")
+ }
+ return x, y, nil
+}
+
+// validateECPublicKey checks that the point is a valid public key for
+// the given curve. See [SEC1], 3.2.2
+func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
+ if x.Sign() == 0 && y.Sign() == 0 {
+ return false
+ }
+
+ if x.Cmp(curve.Params().P) >= 0 {
+ return false
+ }
+
+ if y.Cmp(curve.Params().P) >= 0 {
+ return false
+ }
+
+ if !curve.IsOnCurve(x, y) {
+ return false
+ }
+
+ // We don't check if N * PubKey == 0, since
+ //
+ // - the NIST curves have cofactor = 1, so this is implicit.
+ // (We don't forsee an implementation that supports non NIST
+ // curves)
+ //
+ // - for ephemeral keys, we don't need to worry about small
+ // subgroup attacks.
+ return true
+}
+
+func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+ packet, err := c.readPacket()
+ if err != nil {
+ return nil, err
+ }
+
+ var kexECDHInit kexECDHInitMsg
+ if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil {
+ return nil, err
+ }
+
+ clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey)
+ if err != nil {
+ return nil, err
+ }
+
+ // We could cache this key across multiple users/multiple
+ // connection attempts, but the benefit is small. OpenSSH
+ // generates a new key for each incoming connection.
+ ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
+ if err != nil {
+ return nil, err
+ }
+
+ hostKeyBytes := MarshalPublicKey(priv.PublicKey())
+
+ serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
+
+ // generate shared secret
+ secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
+
+ h := ecHash(kex.curve).New()
+ magics.write(h)
+ writeString(h, hostKeyBytes)
+ writeString(h, kexECDHInit.ClientPubKey)
+ writeString(h, serializedEphKey)
+
+ K := make([]byte, intLength(secret))
+ marshalInt(K, secret)
+ h.Write(K)
+
+ H := h.Sum(nil)
+
+ // H is already a hash, but the hostkey signing will apply its
+ // own key-specific hash algorithm.
+ sig, err := signAndMarshal(priv, rand, H)
+ if err != nil {
+ return nil, err
+ }
+
+ reply := kexECDHReplyMsg{
+ EphemeralPubKey: serializedEphKey,
+ HostKey: hostKeyBytes,
+ Signature: sig,
+ }
+
+ serialized := marshal(msgKexECDHReply, reply)
+ if err := c.writePacket(serialized); err != nil {
+ return nil, err
+ }
+
+ return &kexResult{
+ H: H,
+ K: K,
+ HostKey: reply.HostKey,
+ Signature: sig,
+ }, nil
+}
+
+func (kex *ecdh) Hash() crypto.Hash {
+ return ecHash(kex.curve)
+}
+
+var kexAlgoMap = map[string]kexAlgorithm{}
+
+func init() {
+ // This is the group called diffie-hellman-group1-sha1 in RFC
+ // 4253 and Oakley Group 2 in RFC 2409.
+ p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
+ kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
+ g: new(big.Int).SetInt64(2),
+ p: p,
+ }
+
+ // This is the group called diffie-hellman-group14-sha1 in RFC
+ // 4253 and Oakley Group 14 in RFC 3526.
+ p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+
+ kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
+ g: new(big.Int).SetInt64(2),
+ p: p,
+ }
+
+ kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
+ kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
+ kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
+}
diff --git a/ssh/kex_test.go b/ssh/kex_test.go
index 104009e..74e80c4 100644
--- a/ssh/kex_test.go
+++ b/ssh/kex_test.go
@@ -7,84 +7,73 @@
// Key exchange tests.
import (
- "fmt"
- "net"
+ "crypto/rand"
+ "io"
+ "reflect"
"testing"
)
-func pipe() (net.Conn, net.Conn, error) {
- l, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return nil, nil, err
- }
- conn1, err := net.Dial("tcp", l.Addr().String())
- if err != nil {
- return nil, nil, err
- }
-
- conn2, err := l.Accept()
- if err != nil {
- conn1.Close()
- return nil, nil, err
- }
- l.Close()
- return conn1, conn2, nil
+// An in-memory packetConn.
+type memTransport struct {
+ r, w chan []byte
}
-func testKexAlgorithm(algo string) error {
- crypto := CryptoConfig{
- KeyExchanges: []string{algo},
- }
- serverConfig := ServerConfig{
- PasswordCallback: func(conn *ServerConn, user, password string) bool {
- return password == "password"
- },
- Crypto: crypto,
+func (t *memTransport) readPacket() ([]byte, error) {
+ p, ok := <-t.r
+ if !ok {
+ return nil, io.EOF
}
- if err := serverConfig.SetRSAPrivateKey([]byte(testServerPrivateKey)); err != nil {
- return fmt.Errorf("SetRSAPrivateKey: %v", err)
- }
+ return p, nil
+}
- clientConfig := ClientConfig{
- User: "user",
- Auth: []ClientAuth{ClientAuthPassword(password("password"))},
- Crypto: crypto,
- }
-
- conn1, conn2, err := pipe()
- if err != nil {
- return err
- }
-
- defer conn1.Close()
- defer conn2.Close()
-
- server := Server(conn2, &serverConfig)
- serverHS := make(chan error, 1)
- go func() {
- serverHS <- server.Handshake()
- }()
-
- // Client runs the handshake.
- _, err = Client(conn1, &clientConfig)
- if err != nil {
- return fmt.Errorf("Client: %v", err)
- }
-
- if err := <-serverHS; err != nil {
- return fmt.Errorf("server.Handshake: %v", err)
- }
-
- // Here we could check that we now can send data between client &
- // server.
+func (t *memTransport) Close() error {
+ close(t.w)
return nil
}
-func TestKexAlgorithms(t *testing.T) {
- for _, algo := range []string{kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, kexAlgoDH1SHA1, kexAlgoDH14SHA1} {
- if err := testKexAlgorithm(algo); err != nil {
- t.Errorf("algorithm %s: %v", algo, err)
+func (t *memTransport) writePacket(p []byte) error {
+ t.w <- p
+ return nil
+}
+
+func memPipe() (a, b packetConn) {
+ p := make(chan []byte, 1)
+ q := make(chan []byte, 1)
+ return &memTransport{p, q}, &memTransport{q, p}
+}
+
+func TestKexes(t *testing.T) {
+ type kexResultErr struct {
+ result *kexResult
+ err error
+ }
+
+ for name, kex := range kexAlgoMap {
+ a, b := memPipe()
+
+ s := make(chan kexResultErr, 1)
+ c := make(chan kexResultErr, 1)
+ var magics handshakeMagics
+ go func() {
+ r, e := kex.Client(a, rand.Reader, &magics)
+ c <- kexResultErr{r, e}
+ }()
+ go func() {
+ r, e := kex.Server(b, rand.Reader, &magics, ecdsaKey)
+ s <- kexResultErr{r, e}
+ }()
+
+ clientRes := <-c
+ serverRes := <-s
+ if clientRes.err != nil {
+ t.Errorf("client: %v", clientRes.err)
+ }
+ if serverRes.err != nil {
+ t.Errorf("server: %v", serverRes.err)
+ }
+ if !reflect.DeepEqual(clientRes.result, serverRes.result) {
+ t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
}
}
}
diff --git a/ssh/server.go b/ssh/server.go
index ffc35dd..cb6fe8c 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -6,14 +6,11 @@
import (
"bytes"
- "crypto"
- "crypto/ecdsa"
- "crypto/elliptic"
"crypto/rand"
"encoding/binary"
"errors"
+ "fmt"
"io"
- "math/big"
"net"
"sync"
@@ -140,177 +137,6 @@
}
}
-// kexECDH performs Elliptic Curve Diffie-Hellman key agreement on a
-// ServerConnection, as documented in RFC 5656, section 4.
-func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
- packet, err := s.readPacket()
- if err != nil {
- return
- }
-
- var kexECDHInit kexECDHInitMsg
- if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil {
- return
- }
-
- clientX, clientY := elliptic.Unmarshal(curve, kexECDHInit.ClientPubKey)
- if clientX == nil {
- return nil, errors.New("ssh: elliptic.Unmarshal failure")
- }
-
- if !validateECPublicKey(curve, clientX, clientY) {
- return nil, errors.New("ssh: not a valid EC public key")
- }
-
- // We could cache this key across multiple users/multiple
- // connection attempts, but the benefit is small. OpenSSH
- // generates a new key for each incoming connection.
- ephKey, err := ecdsa.GenerateKey(curve, s.config.rand())
- if err != nil {
- return nil, err
- }
-
- hostKeyBytes := MarshalPublicKey(priv.PublicKey())
-
- serializedEphKey := elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
-
- // generate shared secret
- secret, _ := curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
-
- hashFunc := ecHash(curve)
- h := hashFunc.New()
- writeString(h, magics.clientVersion)
- writeString(h, magics.serverVersion)
- writeString(h, magics.clientKexInit)
- writeString(h, magics.serverKexInit)
- writeString(h, hostKeyBytes)
- writeString(h, kexECDHInit.ClientPubKey)
- writeString(h, serializedEphKey)
-
- K := make([]byte, intLength(secret))
- marshalInt(K, secret)
- h.Write(K)
-
- H := h.Sum(nil)
-
- // H is already a hash, but the hostkey signing will apply its
- // own key specific hash algorithm.
- sig, err := signAndMarshal(priv, s.config.rand(), H)
- if err != nil {
- return nil, err
- }
-
- reply := kexECDHReplyMsg{
- EphemeralPubKey: serializedEphKey,
- HostKey: hostKeyBytes,
- Signature: sig,
- }
-
- serialized := marshal(msgKexECDHReply, reply)
- if err := s.writePacket(serialized); err != nil {
- return nil, err
- }
-
- return &kexResult{
- H: H,
- K: K,
- HostKey: reply.HostKey,
- Hash: hashFunc,
- }, nil
-}
-
-// validateECPublicKey checks that the point is a valid public key for
-// the given curve. See [SEC1], 3.2.2
-func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
- if x.Sign() == 0 && y.Sign() == 0 {
- return false
- }
-
- if x.Cmp(curve.Params().P) >= 0 {
- return false
- }
-
- if y.Cmp(curve.Params().P) >= 0 {
- return false
- }
-
- if !curve.IsOnCurve(x, y) {
- return false
- }
-
- // We don't check if N * PubKey == 0, since
- //
- // - the NIST curves have cofactor = 1, so this is implicit.
- // (We don't forsee an implementation that supports non NIST
- // curves)
- //
- // - for ephemeral keys, we don't need to worry about small
- // subgroup attacks.
- return true
-}
-
-// kexDH performs Diffie-Hellman key agreement on a ServerConnection.
-func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
- packet, err := s.readPacket()
- if err != nil {
- return
- }
- var kexDHInit kexDHInitMsg
- if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
- return
- }
-
- y, err := rand.Int(s.config.rand(), group.p)
- if err != nil {
- return
- }
-
- Y := new(big.Int).Exp(group.g, y, group.p)
- kInt, err := group.diffieHellman(kexDHInit.X, y)
- if err != nil {
- return nil, err
- }
-
- hostKeyBytes := MarshalPublicKey(priv.PublicKey())
-
- h := hashFunc.New()
- writeString(h, magics.clientVersion)
- writeString(h, magics.serverVersion)
- writeString(h, magics.clientKexInit)
- writeString(h, magics.serverKexInit)
- writeString(h, hostKeyBytes)
- writeInt(h, kexDHInit.X)
- writeInt(h, Y)
-
- K := make([]byte, intLength(kInt))
- marshalInt(K, kInt)
- h.Write(K)
-
- H := h.Sum(nil)
-
- // H is already a hash, but the hostkey signing will apply its
- // own key specific hash algorithm.
- sig, err := signAndMarshal(priv, s.config.rand(), H)
- if err != nil {
- return nil, err
- }
-
- kexDHReply := kexDHReplyMsg{
- HostKey: hostKeyBytes,
- Y: Y,
- Signature: sig,
- }
- packet = marshal(msgKexDHReply, kexDHReply)
-
- err = s.writePacket(packet)
- return &kexResult{
- H: H,
- K: K,
- HostKey: hostKeyBytes,
- Hash: hashFunc,
- }, nil
-}
-
// signAndMarshal signs the data with the appropriate algorithm,
// and serializes the result in SSH wire format.
func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
@@ -330,8 +156,8 @@
if _, err = s.Write(serverVersion); err != nil {
return
}
- if err = s.Flush(); err != nil {
- return
+ if err := s.Flush(); err != nil {
+ return err
}
s.ClientVersion, err = readVersion(s)
@@ -415,32 +241,22 @@
}
}
- var magics handshakeMagics
- magics.serverVersion = serverVersion[:len(serverVersion)-2]
- magics.clientVersion = s.ClientVersion
- magics.serverKexInit = marshal(msgKexInit, serverKexInit)
- magics.clientKexInit = clientKexInitPacket
+ kex, ok := kexAlgoMap[kexAlgo]
+ if !ok {
+ return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+ }
- var result *kexResult
- switch kexAlgo {
- case kexAlgoECDH256:
- result, err = s.kexECDH(elliptic.P256(), &magics, hostKey)
- case kexAlgoECDH384:
- result, err = s.kexECDH(elliptic.P384(), &magics, hostKey)
- case kexAlgoECDH521:
- result, err = s.kexECDH(elliptic.P521(), &magics, hostKey)
- case kexAlgoDH14SHA1:
- dhGroup14Once.Do(initDHGroup14)
- result, err = s.kexDH(dhGroup14, crypto.SHA1, &magics, hostKey)
- case kexAlgoDH1SHA1:
- dhGroup1Once.Do(initDHGroup1)
- result, err = s.kexDH(dhGroup1, crypto.SHA1, &magics, hostKey)
- default:
- err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
+ magics := handshakeMagics{
+ serverVersion: serverVersion[:len(serverVersion)-2],
+ clientVersion: s.ClientVersion,
+ serverKexInit: marshal(msgKexInit, serverKexInit),
+ clientKexInit: clientKexInitPacket,
}
+ result, err := kex.Server(s, s.config.rand(), &magics, hostKey)
if err != nil {
- return
+ return err
}
+
// sessionId must only be assigned during initial handshake.
if s.sessionId == nil {
s.sessionId = result.H
@@ -451,7 +267,7 @@
if err = s.writePacket([]byte{msgNewKeys}); err != nil {
return
}
- if err = s.transport.writer.setupKeys(serverKeys, result.K, result.H, s.sessionId, result.Hash); err != nil {
+ if err = s.transport.writer.setupKeys(serverKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
return
}
@@ -461,7 +277,7 @@
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
- if err = s.transport.reader.setupKeys(clientKeys, result.K, result.H, s.sessionId, result.Hash); err != nil {
+ if err = s.transport.reader.setupKeys(clientKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
return
}
@@ -760,10 +576,10 @@
}
c := &serverChan{
channel: channel{
- conn: s,
- remoteId: msg.PeersId,
- remoteWin: window{Cond: newCond()},
- maxPacket: msg.MaxPacketSize,
+ packetConn: s,
+ remoteId: msg.PeersId,
+ remoteWin: window{Cond: newCond()},
+ maxPacket: msg.MaxPacketSize,
},
chanType: msg.ChanType,
extraData: msg.TypeSpecificData,
diff --git a/ssh/transport.go b/ssh/transport.go
index c222caf..bbc4d80 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -29,13 +29,16 @@
maxPacket = 256 * 1024
)
-// conn represents an ssh transport that implements packet based
+// packetConn represents a transport that implements packet based
// operations.
-type conn interface {
+type packetConn interface {
// Encrypt and send a packet of data to the remote peer.
writePacket(packet []byte) error
- // Close closes the connection.
+ // Read a packet from the connection
+ readPacket() ([]byte, error)
+
+ // Close closes the write-side of the connection.
Close() error
}
@@ -74,7 +77,7 @@
}
// Read and decrypt a single packet from the remote peer.
-func (r *reader) readOnePacket() ([]byte, error) {
+func (r *reader) readPacket() ([]byte, error) {
var lengthBytes = make([]byte, 5)
var macSize uint32
if _, err := io.ReadFull(r, lengthBytes); err != nil {
@@ -128,7 +131,7 @@
// Read and decrypt next packet discarding debug and noop messages.
func (t *transport) readPacket() ([]byte, error) {
for {
- packet, err := t.readOnePacket()
+ packet, err := t.reader.readPacket()
if err != nil {
return nil, err
}