ssh: cosmetic cleanups
These are the cosmetic cleanups from the bits of code that I
rereviewed.
1) stringLength now takes a int; the length of the string. Too many
callers were allocating with stringLength([]byte(s)) and
stringLength only needs to call len().
2) agent.go now has sendAndReceive to remove logic that was
duplicated.
3) We now reject negative DH values
4) We now reject empty packets rather than crashing.
R=dave, jonathan.mark.pittman
CC=golang-dev
https://golang.org/cl/6061052
diff --git a/ssh/agent.go b/ssh/agent.go
index f63ce6d..3c701b7 100644
--- a/ssh/agent.go
+++ b/ssh/agent.go
@@ -10,7 +10,6 @@
import (
"encoding/base64"
"errors"
- "fmt"
"io"
)
@@ -44,6 +43,10 @@
agentConstrainConfirm = 2
)
+// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
+// is a sanity check, not a limit in the spec.
+const maxAgentResponseBytes = 16 << 20
+
// Agent messages:
// These structures mirror the wire format of the corresponding ssh agent
// messages found in PROTOCOL.agent.
@@ -85,18 +88,16 @@
func (ak *AgentKey) String() string {
algo, _, ok := parseString(ak.blob)
if !ok {
- return "malformed key"
+ return "ssh: malformed key"
}
- algoName := string(algo)
- b64EncKey := base64.StdEncoding.EncodeToString(ak.blob)
- comment := ""
+ s := string(algo) + " " + base64.StdEncoding.EncodeToString(ak.blob)
if ak.Comment != "" {
- comment = " " + ak.Comment
+ s += " " + ak.Comment
}
- return fmt.Sprintf("%s %s%s", algoName, b64EncKey, comment)
+ return s
}
// Key returns an agent's public key as a *rsa.PublicKey, *dsa.PublicKey, or
@@ -131,50 +132,51 @@
io.ReadWriter
}
-func (ac *AgentClient) sendRequest(req []byte) error {
- msg := make([]byte, stringLength(req))
+// sendAndReceive sends req to the agent and waits for a reply. On success,
+// the reply is unmarshaled into reply and replyType is set to the first byte of
+// the reply, which contains the type of the message.
+func (ac *AgentClient) sendAndReceive(req []byte) (reply interface{}, replyType uint8, err error) {
+ msg := make([]byte, stringLength(len(req)))
marshalString(msg, req)
- if _, err := ac.Write(msg); err != nil {
- return err
+ if _, err = ac.Write(msg); err != nil {
+ return
}
- return nil
-}
-func (ac *AgentClient) readResponse() ([]byte, error) {
var respSizeBuf [4]byte
- if _, err := io.ReadFull(ac, respSizeBuf[:]); err != nil {
- return nil, err
+ if _, err = io.ReadFull(ac, respSizeBuf[:]); err != nil {
+ return
}
+ respSize, _, _ := parseUint32(respSizeBuf[:])
- respSize, _, ok := parseUint32(respSizeBuf[:])
- if !ok {
- return nil, errors.New("ssh: failure to parse response size")
+ if respSize > maxAgentResponseBytes {
+ err = errors.New("ssh: agent reply too large")
+ return
}
buf := make([]byte, respSize)
- if _, err := io.ReadFull(ac, buf); err != nil {
- return nil, err
+ if _, err = io.ReadFull(ac, buf); err != nil {
+ return
}
- return buf, nil
+ return unmarshalAgentMsg(buf)
}
// RequestIdentities queries the agent for protocol 2 keys as defined in
// PROTOCOL.agent section 2.5.2.
func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
req := marshal(agentRequestIdentities, requestIdentitiesAgentMsg{})
- if err := ac.sendRequest(req); err != nil {
- return nil, err
- }
- resp, err := ac.readResponse()
+ msg, msgType, err := ac.sendAndReceive(req)
if err != nil {
return nil, err
}
- switch msg := decodeAgentMsg(resp).(type) {
+ switch msg := msg.(type) {
case *identitiesAnswerAgentMsg:
+ if msg.NumKeys > maxAgentResponseBytes/8 {
+ return nil, errors.New("ssh: too many keys in agent reply")
+ }
keys := make([]*AgentKey, msg.NumKeys)
- data := msg.Keys[:]
+ data := msg.Keys
for i := uint32(0); i < msg.NumKeys; i++ {
var key *AgentKey
var ok bool
@@ -185,11 +187,9 @@
}
return keys, nil
case *failureAgentMsg:
- return nil, errors.New("ssh: failed to list keys.")
- case ParseError, UnexpectedMessageError:
- return nil, msg.(error)
+ return nil, errors.New("ssh: failed to list keys")
}
- return nil, UnexpectedMessageError{agentIdentitiesAnswer, resp[0]}
+ return nil, UnexpectedMessageError{agentIdentitiesAnswer, msgType}
}
// SignRequest requests the signing of data by the agent using a protocol 2 key
@@ -200,29 +200,26 @@
KeyBlob: serializePublickey(key),
Data: data,
})
- if err := ac.sendRequest(req); err != nil {
- return nil, err
- }
- resp, err := ac.readResponse()
+ msg, msgType, err := ac.sendAndReceive(req)
if err != nil {
return nil, err
}
- switch msg := decodeAgentMsg(resp).(type) {
+ switch msg := msg.(type) {
case *signResponseAgentMsg:
return msg.SigBlob, nil
case *failureAgentMsg:
return nil, errors.New("ssh: failed to sign challenge")
- case ParseError, UnexpectedMessageError:
- return nil, msg.(error)
}
- return nil, UnexpectedMessageError{agentSignResponse, resp[0]}
+ return nil, UnexpectedMessageError{agentSignResponse, msgType}
}
-func decodeAgentMsg(packet []byte) interface{} {
+// unmarshalAgentMsg parses an agent message in packet, returning the parsed
+// form and the message type of packet.
+func unmarshalAgentMsg(packet []byte) (interface{}, uint8, error) {
if len(packet) < 1 {
- return ParseError{0}
+ return nil, 0, ParseError{0}
}
var msg interface{}
switch packet[0] {
@@ -235,10 +232,10 @@
case agentSignResponse:
msg = new(signResponseAgentMsg)
default:
- return UnexpectedMessageError{0, packet[0]}
+ return nil, 0, UnexpectedMessageError{0, packet[0]}
}
if err := unmarshal(msg, packet, packet[0]); err != nil {
- return err
+ return nil, 0, err
}
- return msg
+ return msg, packet[0], nil
}
diff --git a/ssh/certs.go b/ssh/certs.go
index 59430ea..107fd1a 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -154,18 +154,18 @@
sigKey := serializePublickey(cert.SignatureKey)
- length := stringLength(cert.Nonce)
+ length := stringLength(len(cert.Nonce))
length += len(pubKey)
length += 8 // Length of Serial
length += 4 // Length of Type
- length += stringLength([]byte(cert.KeyId))
+ length += stringLength(len(cert.KeyId))
length += lengthPrefixedNameListLength(cert.ValidPrincipals)
length += 8 // Length of ValidAfter
length += 8 // Length of ValidBefore
length += tupleListLength(cert.CriticalOptions)
length += tupleListLength(cert.Extensions)
- length += stringLength(cert.Reserved)
- length += stringLength(sigKey)
+ length += stringLength(len(cert.Reserved))
+ length += stringLength(len(sigKey))
length += signatureLength(cert.Signature)
ret := make([]byte, length)
@@ -215,9 +215,7 @@
for len(list) > 0 {
var next []byte
- var ok bool
- next, list, ok = parseString(list)
- if !ok {
+ if next, list, ok = parseString(list); !ok {
return nil, nil, false
}
out = append(out, string(next))
@@ -272,8 +270,8 @@
func signatureLength(sig *signature) int {
length := 4 // length prefix for signature
- length += stringLength([]byte(sig.Format))
- length += stringLength(sig.Blob)
+ length += stringLength(len(sig.Format))
+ length += stringLength(len(sig.Blob))
return length
}
diff --git a/ssh/channel.go b/ssh/channel.go
index 9ee43e4..20bc710 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -56,7 +56,7 @@
}
func (c ChannelRequest) Error() string {
- return "channel request received"
+ return "ssh: channel request received"
}
// RejectionReason is an enumeration used when rejecting channel creation
@@ -255,7 +255,7 @@
}
if c.length > 0 {
- tail := min(c.head + c.length, len(c.pendingData))
+ tail := min(c.head+c.length, len(c.pendingData))
n = copy(data, c.pendingData[c.head:tail])
c.head += n
c.length -= n
@@ -374,18 +374,17 @@
return c.serverConn.err
}
- if ok {
- ack := channelRequestSuccessMsg{
- PeersId: c.theirId,
- }
- return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
- } else {
+ if !ok {
ack := channelRequestFailureMsg{
PeersId: c.theirId,
}
return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
}
- panic("unreachable")
+
+ ack := channelRequestSuccessMsg{
+ PeersId: c.theirId,
+ }
+ return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
}
func (c *channel) ChannelType() string {
diff --git a/ssh/cipher.go b/ssh/cipher.go
index d91929a..0646e1e 100644
--- a/ssh/cipher.go
+++ b/ssh/cipher.go
@@ -35,10 +35,10 @@
}
type cipherMode struct {
- keySize int
- ivSize int
- skip int
- createFn func(key, iv []byte) (cipher.Stream, error)
+ keySize int
+ ivSize int
+ skip int
+ createFunc func(key, iv []byte) (cipher.Stream, error)
}
func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
@@ -49,7 +49,7 @@
panic("ssh: iv too small for cipher")
}
- stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
+ stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
diff --git a/ssh/client.go b/ssh/client.go
index 973dd3e..493d8ec 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -154,16 +154,16 @@
return nil, nil, err
}
- var kexDHReply = new(kexDHReplyMsg)
- if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
+ var kexDHReply kexDHReplyMsg
+ if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
return nil, nil, err
}
- if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
- return nil, nil, errors.New("server DH parameter out of bounds")
+ kInt, err := group.diffieHellman(kexDHReply.Y, x)
+ if err != nil {
+ return nil, nil, err
}
- kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
h := hashFunc.New()
writeString(h, magics.clientVersion)
writeString(h, magics.serverVersion)
@@ -352,7 +352,7 @@
case *channelOpenFailureMsg:
return errors.New(safeString(msg.Message))
}
- return errors.New("unexpected packet")
+ return errors.New("ssh: unexpected packet")
}
// sendEOF sends EOF to the server. RFC 4254 Section 5.3
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 7ae670f..8c030ac 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -213,7 +213,7 @@
}
// manually wrap the serialized signature in a string
s := serializeSignature(algoname, sign)
- sig := make([]byte, stringLength(s))
+ sig := make([]byte, stringLength(len(s)))
marshalString(sig, s)
msg := publickeyAuthMsg{
User: user,
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index c1c57ba..9d1ae2f 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -85,7 +85,7 @@
case *rsa.PrivateKey:
return rsa.SignPKCS1v15(rand, key, hashFunc, digest)
}
- return nil, errors.New("unknown key type")
+ return nil, errors.New("ssh: unknown key type")
}
func (k *keychain) loadPEM(file string) error {
diff --git a/ssh/common.go b/ssh/common.go
index 429b488..e94142c 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -7,6 +7,7 @@
import (
"crypto/dsa"
"crypto/rsa"
+ "errors"
"math/big"
"strconv"
"sync"
@@ -32,6 +33,13 @@
g, p *big.Int
}
+func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
+ if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
+ return nil, errors.New("ssh: DH parameter out of bounds")
+ }
+ return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
+}
+
// dhGroup1 is the group called diffie-hellman-group1-sha1 in RFC 4253 and
// Oakley Group 2 in RFC 2409.
var dhGroup1 *dhGroup
@@ -178,8 +186,8 @@
case hostAlgoDSACertV01:
algoname = "ssh-dss"
}
- length := stringLength([]byte(algoname))
- length += stringLength(sig)
+ length := stringLength(len(algoname))
+ length += stringLength(len(sig))
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
@@ -203,7 +211,7 @@
panic("unexpected key type")
}
- length := stringLength([]byte(algoname))
+ length := stringLength(len(algoname))
length += len(pubKeyBytes)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
@@ -230,14 +238,14 @@
service := []byte(req.Service)
method := []byte(req.Method)
- length := stringLength(sessionId)
+ length := stringLength(len(sessionId))
length += 1
- length += stringLength(user)
- length += stringLength(service)
- length += stringLength(method)
+ length += stringLength(len(user))
+ length += stringLength(len(service))
+ length += stringLength(len(method))
length += 1
- length += stringLength(algo)
- length += stringLength(pubKey)
+ length += stringLength(len(algo))
+ length += stringLength(len(pubKey))
ret := make([]byte, length)
r := marshalString(ret, sessionId)
diff --git a/ssh/keys.go b/ssh/keys.go
index 8322697..1f37864 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -78,7 +78,7 @@
// marshalPrivRSA serializes an RSA private key according to RFC 4253, section 6.6.
func marshalPrivRSA(priv *rsa.PrivateKey) []byte {
e := new(big.Int).SetInt64(int64(priv.E))
- length := stringLength([]byte(hostAlgoRSA))
+ length := stringLength(len(hostAlgoRSA))
length += intLength(e)
length += intLength(priv.N)
diff --git a/ssh/messages.go b/ssh/messages.go
index d61c6c7..3efe81f 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -543,8 +543,8 @@
w.Write(s)
}
-func stringLength(s []byte) int {
- return 4 + len(s)
+func stringLength(n int) int {
+ return 4 + n
}
func marshalString(to []byte, s []byte) []byte {
diff --git a/ssh/server.go b/ssh/server.go
index e669b5c..155e685 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -141,24 +141,23 @@
return
}
- if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
- return nil, nil, errors.New("client DH parameter out of bounds")
- }
-
y, err := rand.Int(s.config.rand(), group.p)
if err != nil {
return
}
Y := new(big.Int).Exp(group.g, y, group.p)
- kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
+ kInt, err := group.diffieHellman(kexDHInit.X, y)
+ if err != nil {
+ return nil, nil, err
+ }
var serializedHostKey []byte
switch hostKeyAlgo {
case hostAlgoRSA:
serializedHostKey = s.config.rsaSerialized
default:
- return nil, nil, errors.New("internal error")
+ return nil, nil, errors.New("ssh: internal error")
}
h := hashFunc.New()
@@ -187,7 +186,7 @@
return
}
default:
- return nil, nil, errors.New("internal error")
+ return nil, nil, errors.New("ssh: internal error")
}
serializedSig := serializeSignature(hostAlgoRSA, sig)
diff --git a/ssh/session.go b/ssh/session.go
index ea4addb..7948c00 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -231,9 +231,9 @@
case *channelRequestSuccessMsg:
return nil
case *channelRequestFailureMsg:
- return errors.New("request failed")
+ return errors.New("ssh: request failed")
}
- return fmt.Errorf("unknown packet %T received: %v", msg, msg)
+ return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg)
}
func (s *Session) start() error {
diff --git a/ssh/transport.go b/ssh/transport.go
index c76116a..f253ce5 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -105,10 +105,10 @@
paddingLength := uint32(lengthBytes[4])
if length <= paddingLength+1 {
- return nil, errors.New("invalid packet length")
+ return nil, errors.New("ssh: invalid packet length")
}
if length > maxPacketSize {
- return nil, errors.New("packet too large")
+ return nil, errors.New("ssh: packet too large")
}
packet := make([]byte, length-1+macSize)
@@ -136,6 +136,9 @@
if err != nil {
return nil, err
}
+ if len(packet) == 0 {
+ return nil, errors.New("ssh: zero length packet")
+ }
if packet[0] != msgIgnore && packet[0] != msgDebug {
return packet, nil
}