x/crypto/ssh: make sure the initial key exchange happens once.
This is done by running the key exchange and setting the session ID
while holding the mutex. If the initial key exchange finds that the
session ID is already set, it does nothing.
This fixes a race condition:
When setting up the connection, both sides send a kexInit to initiate
the first (mandatory) key exchange. If one side was faster, it might
complete the key exchange before the slower side had a chance to send
its kexInit. The slower side would then send a kexInit, which would
trigger a second key exchange. The resulting confirmation message
(msgNewKeys) would confuse the authentication loop.
This fix moves sessionID from the transport struct into
handshakeTransport, where it is guarded by the handshake mutex.
This fix also deletes the unused interface rekeyingTransport.
Fixes #15066
Change-Id: I7f303bce5d3214c9bdd58f52d21178a185871d90
Reviewed-on: https://go-review.googlesource.com/21606
Reviewed-by: Adam Langley <agl@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/client.go b/ssh/client.go
index bc6f47a..e0f1a4d 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -97,7 +97,7 @@
c.transport = newClientTransport(
newTransport(c.sshConn.conn, config.Rand, true /* is client */),
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
- if err := c.transport.requestKeyChange(); err != nil {
+ if err := c.transport.requestInitialKeyChange(); err != nil {
return err
}
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 1c54f75..08abd66 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -29,25 +29,6 @@
// direction will be effected if a msgNewKeys message is sent
// or received.
prepareKeyChange(*algorithms, *kexResult) error
-
- // getSessionID returns the session ID. prepareKeyChange must
- // have been called once.
- getSessionID() []byte
-}
-
-// rekeyingTransport is the interface of handshakeTransport that we
-// (internally) expose to ClientConn and ServerConn.
-type rekeyingTransport interface {
- packetConn
-
- // requestKeyChange asks the remote side to change keys. All
- // writes are blocked until the key change succeeds, which is
- // signaled by reading a msgNewKeys.
- requestKeyChange() error
-
- // getSessionID returns the session ID. This is only valid
- // after the first key change has completed.
- getSessionID() []byte
}
// handshakeTransport implements rekeying on top of a keyingTransport
@@ -86,6 +67,9 @@
sentInitMsg *kexInitMsg
writtenSinceKex uint64
writeError error
+
+ // The session ID, or nil if the first kex has not completed yet.
+ sessionID []byte
}
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
@@ -122,7 +106,7 @@
}
func (t *handshakeTransport) getSessionID() []byte {
- return t.conn.getSessionID()
+ return t.sessionID
}
func (t *handshakeTransport) id() string {
@@ -183,9 +167,9 @@
if p[0] != msgKexInit {
return p, nil
}
- err = t.enterKeyExchange(p)
t.mu.Lock()
+ err = t.enterKeyExchangeLocked(p)
if err != nil {
// drop connection
t.conn.Close()
@@ -211,25 +195,39 @@
return []byte{msgNewKeys}, nil
}
+// keyChangeCategory describes whether a key exchange is the first on a
+// connection, or a subsequent one.
+type keyChangeCategory bool
+
+const (
+ firstKeyExchange keyChangeCategory = true
+ subsequentKeyExchange keyChangeCategory = false
+)
+
// sendKexInit sends a key change message, and returns the message
// that was sent. After initiating the key change, all writes will be
// blocked until the change is done, and a failed key change will
// close the underlying transport. This function is safe for
// concurrent use by multiple goroutines.
-func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
+func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
t.mu.Lock()
defer t.mu.Unlock()
- return t.sendKexInitLocked()
+ return t.sendKexInitLocked(isFirst)
+}
+
+func (t *handshakeTransport) requestInitialKeyChange() error {
+ _, _, err := t.sendKexInit(firstKeyExchange)
+ return err
}
func (t *handshakeTransport) requestKeyChange() error {
- _, _, err := t.sendKexInit()
+ _, _, err := t.sendKexInit(subsequentKeyExchange)
return err
}
// sendKexInitLocked sends a key change message. t.mu must be locked
// while this happens.
-func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
+func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
// kexInits may be sent either in response to the other side,
// or because our side wants to initiate a key change, so we
// may have already sent a kexInit. In that case, don't send a
@@ -237,6 +235,14 @@
if t.sentInitMsg != nil {
return t.sentInitMsg, t.sentInitPacket, nil
}
+
+ // If this is the initial key change, but we already have a sessionID,
+ // then do nothing because the key exchange has already completed
+ // asynchronously.
+ if isFirst && t.sessionID != nil {
+ return nil, nil, nil
+ }
+
msg := &kexInitMsg{
KexAlgos: t.config.KeyExchanges,
CiphersClientServer: t.config.Ciphers,
@@ -276,7 +282,7 @@
defer t.mu.Unlock()
if t.writtenSinceKex > t.config.RekeyThreshold {
- t.sendKexInitLocked()
+ t.sendKexInitLocked(subsequentKeyExchange)
}
for t.sentInitMsg != nil && t.writeError == nil {
t.cond.Wait()
@@ -300,12 +306,12 @@
return t.conn.Close()
}
-// enterKeyExchange runs the key exchange.
-func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
+// enterKeyExchangeLocked runs the key exchange. t.mu must be held while running this.
+func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
if debugHandshake {
log.Printf("%s entered key exchange", t.id())
}
- myInit, myInitPacket, err := t.sendKexInit()
+ myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
if err != nil {
return err
}
@@ -362,6 +368,11 @@
return err
}
+ if t.sessionID == nil {
+ t.sessionID = result.H
+ }
+ result.SessionID = t.sessionID // RFC 4253 7.2: every kex derives keys from the first H.
+
t.conn.prepareKeyChange(algs, result)
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err
@@ -371,6 +382,7 @@
} else if packet[0] != msgNewKeys {
return unexpectedMessageError(msgNewKeys, packet[0])
}
+
return nil
}
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index bd7fe77..ef97d57 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -104,7 +104,7 @@
}
if i == 5 {
// halfway through, we request a key change.
- _, _, err := trC.sendKexInit()
+ _, _, err := trC.sendKexInit(subsequentKeyExchange)
if err != nil {
t.Fatalf("sendKexInit: %v", err)
}
@@ -161,7 +161,7 @@
}
// Now request a key change.
- _, _, err = trC.sendKexInit()
+ _, _, err = trC.sendKexInit(subsequentKeyExchange)
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
@@ -202,7 +202,7 @@
}
// Now request a key change.
- _, _, err = trC.sendKexInit()
+ _, _, err = trC.sendKexInit(subsequentKeyExchange)
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
@@ -215,7 +215,7 @@
}
// 2nd key change.
- _, _, err = trC.sendKexInit()
+ _, _, err = trC.sendKexInit(subsequentKeyExchange)
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
@@ -430,7 +430,7 @@
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
errMsg := &disconnectMsg{
- Reason: 42,
+ Reason: 42,
Message: "such is life",
}
trC.writePacket(Marshal(errMsg))
@@ -441,7 +441,7 @@
t.Fatalf("readPacket 1: %v", err)
}
if packet[0] != msgRequestSuccess {
- t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
+ t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
}
_, err = trS.readPacket()
diff --git a/ssh/kex.go b/ssh/kex.go
index 3ec603c..9285ee3 100644
--- a/ssh/kex.go
+++ b/ssh/kex.go
@@ -46,7 +46,7 @@
Hash crypto.Hash
// The session ID, which is the first H computed. This is used
- // to signal data inside transport.
+ // to derive key material inside the transport.
SessionID []byte
}
diff --git a/ssh/server.go b/ssh/server.go
index 4781eb7..d530501 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -188,7 +188,7 @@
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
- if err := s.transport.requestKeyChange(); err != nil {
+ if err := s.transport.requestInitialKeyChange(); err != nil {
return nil, err
}
diff --git a/ssh/transport.go b/ssh/transport.go
index 4de98a6..bf7dd61 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -39,19 +39,6 @@
rand io.Reader
io.Closer
-
- // Initial H used for the session ID. Once assigned this does
- // not change, even during subsequent key exchanges.
- sessionID []byte
-}
-
-// getSessionID returns the ID of the SSH connection. The return value
-// should not be modified.
-func (t *transport) getSessionID() []byte {
- if t.sessionID == nil {
- panic("session ID not set yet")
- }
- return t.sessionID
}
// packetCipher represents a combination of SSH encryption/MAC
@@ -81,12 +68,6 @@
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
- if t.sessionID == nil {
- t.sessionID = kexResult.H
- }
-
- kexResult.SessionID = t.sessionID
-
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
return err
} else {
@@ -119,7 +100,7 @@
case msgNewKeys:
select {
case cipher := <-s.pendingKeyChange:
- s.packetCipher = cipher
+ s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message.")
}