ssh: rewrite (re)keying logic.
Use channels and a dedicated write loop for managing the rekeying
process. This lets us collect packets to be written while a key
exchange is in progress.
Previously, the read loop ran the key exchange, and writers would
block if a key exchange was going on. If a reader wrote back a packet
while processing a read packet, it could block, stopping the read
loop, thus causing a deadlock. Such coupled read/writes are inherent
with handling requests that want a response (eg. keepalive,
opening/closing channels etc.). The buffered channels (most channels
have capacity 16) papered over these problems, but under load SSH
connections would occasionally deadlock.
Fixes #18439.
Change-Id: I7c14ff4991fa3100a5d36025125d0cf1119c471d
Reviewed-on: https://go-review.googlesource.com/35012
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/client.go b/ssh/client.go
index 0212a20..c841e8d 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -97,13 +97,11 @@
c.transport = newClientTransport(
newTransport(c.sshConn.conn, config.Rand, true /* is client */),
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
- if err := c.transport.requestInitialKeyChange(); err != nil {
+ if err := c.transport.waitSession(); err != nil {
return err
}
- // We just did the key change, so the session ID is established.
c.sessionID = c.transport.getSessionID()
-
return c.clientAuthenticate(config)
}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index 294af0d..fd1ec5d 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -30,8 +30,10 @@
// then any untried methods suggested by the server.
tried := make(map[string]bool)
var lastMethods []string
+
+ sessionID := c.transport.getSessionID()
for auth := AuthMethod(new(noneAuth)); auth != nil; {
- ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
+ ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand)
if err != nil {
return err
}
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 37d42e4..03c950d 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -53,6 +53,20 @@
incoming chan []byte
readError error
+ mu sync.Mutex
+ writeError error
+ sentInitPacket []byte
+ sentInitMsg *kexInitMsg
+ pendingPackets [][]byte // Used when a key exchange is in progress.
+
+ // If the read loop wants to schedule a kex, it pings this
+ // channel, and the write loop will send out a kex message.
+ requestKex chan struct{}
+
+ // If the other side requests or confirms a kex, its kexInit
+ // packet is sent here for the write loop to find it.
+ startKex chan *pendingKex
+
// data for host key checking
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
dialAddress string
@@ -60,27 +74,28 @@
readSinceKex uint64
- // Protects the writing side of the connection
- mu sync.Mutex
- cond *sync.Cond
- sentInitPacket []byte
- sentInitMsg *kexInitMsg
writtenSinceKex uint64
- writeError error
// The session ID or nil if first kex did not complete yet.
sessionID []byte
}
+type pendingKex struct {
+ otherInit []byte
+ done chan error
+}
+
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
t := &handshakeTransport{
conn: conn,
serverVersion: serverVersion,
clientVersion: clientVersion,
incoming: make(chan []byte, 16),
- config: config,
+ requestKex: make(chan struct{}, 1),
+ startKex: make(chan *pendingKex, 1),
+
+ config: config,
}
- t.cond = sync.NewCond(&t.mu)
return t
}
@@ -95,6 +110,7 @@
t.hostKeyAlgorithms = supportedHostKeyAlgos
}
go t.readLoop()
+ go t.kexLoop()
return t
}
@@ -102,6 +118,7 @@
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
t.hostKeys = config.hostKeys
go t.readLoop()
+ go t.kexLoop()
return t
}
@@ -109,6 +126,20 @@
return t.sessionID
}
+// waitSession waits for the session to be established. This should be
+// the first thing to call after instantiating handshakeTransport.
+func (t *handshakeTransport) waitSession() error {
+ p, err := t.readPacket()
+ if err != nil {
+ return err
+ }
+ if p[0] != msgNewKeys {
+ return fmt.Errorf("ssh: first packet should be msgNewKeys")
+ }
+
+ return nil
+}
+
func (t *handshakeTransport) id() string {
if len(t.hostKeys) > 0 {
return "server"
@@ -116,6 +147,19 @@
return "client"
}
+func (t *handshakeTransport) printPacket(p []byte, write bool) {
+ action := "got"
+ if write {
+ action = "sent"
+ }
+ if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
+ log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
+ } else {
+ msg, err := decode(p)
+ log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
+ }
+}
+
func (t *handshakeTransport) readPacket() ([]byte, error) {
p, ok := <-t.incoming
if !ok {
@@ -125,8 +169,16 @@
}
func (t *handshakeTransport) readLoop() {
+ // We always start with the mandatory key exchange. We use
+ // the channel for simplicity, and this works if we can rely
+ // on the SSH package itself not doing anything else before
+ // waitSession has completed.
+ t.requestKeyExchange()
+
+ first := true
for {
- p, err := t.readOnePacket()
+ p, err := t.readOnePacket(first)
+ first = false
if err != nil {
t.readError = err
close(t.incoming)
@@ -138,20 +190,123 @@
t.incoming <- p
}
- // If we can't read, declare the writing part dead too.
- t.mu.Lock()
- defer t.mu.Unlock()
- if t.writeError == nil {
- t.writeError = t.readError
- }
- t.cond.Broadcast()
+ // Stop writers too.
+ t.recordWriteError(t.readError)
+
+ // Unblock the writer should it wait for this.
+ close(t.startKex)
+
+ // Don't close t.requestKex; it's also written to from writePacket.
}
-func (t *handshakeTransport) readOnePacket() ([]byte, error) {
- if t.readSinceKex > t.config.RekeyThreshold {
- if err := t.requestKeyChange(); err != nil {
- return nil, err
+func (t *handshakeTransport) pushPacket(p []byte) error {
+ if debugHandshake {
+ t.printPacket(p, true)
+ }
+ return t.conn.writePacket(p)
+}
+
+func (t *handshakeTransport) getWriteError() error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.writeError
+}
+
+func (t *handshakeTransport) recordWriteError(err error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.writeError == nil && err != nil {
+ t.writeError = err
+ }
+}
+
+func (t *handshakeTransport) requestKeyExchange() {
+ select {
+ case t.requestKex <- struct{}{}:
+ default:
+ // something already requested a kex, so do nothing.
+ }
+
+}
+
+func (t *handshakeTransport) kexLoop() {
+write:
+ for t.getWriteError() == nil {
+ var request *pendingKex
+ var sent bool
+
+ for request == nil || !sent {
+ var ok bool
+ select {
+ case request, ok = <-t.startKex:
+ if !ok {
+ break write
+ }
+ case <-t.requestKex:
+ }
+
+ if !sent {
+ if err := t.sendKexInit(); err != nil {
+ t.recordWriteError(err)
+ break
+ }
+ sent = true
+ }
}
+
+ if err := t.getWriteError(); err != nil {
+ if request != nil {
+ request.done <- err
+ }
+ break
+ }
+
+ // We're not servicing t.requestKex, but that is OK:
+ // we never block on sending to t.requestKex.
+
+ // We're not servicing t.startKex, but the remote end
+ // has just sent us a kexInitMsg, so it can't send
+ // another key change request.
+
+ err := t.enterKeyExchange(request.otherInit)
+
+ t.mu.Lock()
+ t.writeError = err
+ t.sentInitPacket = nil
+ t.sentInitMsg = nil
+ t.writtenSinceKex = 0
+ request.done <- t.writeError
+
+ // kex finished. Push packets that we received while
+ // the kex was in progress. Don't look at t.startKex
+ // and don't increment writtenSinceKex: if we trigger
+ // another kex while we are still busy with the last
+ // one, things will become very confusing.
+ for _, p := range t.pendingPackets {
+ t.writeError = t.pushPacket(p)
+ if t.writeError != nil {
+ break
+ }
+ }
+ t.pendingPackets = t.pendingPackets[0:]
+ t.mu.Unlock()
+ }
+
+ // drain startKex channel. We don't service t.requestKex
+ // because nobody does blocking sends there.
+ go func() {
+ for init := range t.startKex {
+ init.done <- t.writeError
+ }
+ }()
+
+ // Unblock reader.
+ t.conn.Close()
+}
+
+func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
+ if t.readSinceKex > t.config.RekeyThreshold {
+ t.requestKeyExchange()
}
p, err := t.conn.readPacket()
@@ -161,39 +316,30 @@
t.readSinceKex += uint64(len(p))
if debugHandshake {
- if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
- log.Printf("%s got data (packet %d bytes)", t.id(), len(p))
- } else {
- msg, err := decode(p)
- log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
- }
+ t.printPacket(p, false)
}
+
+ if first && p[0] != msgKexInit {
+ return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
+ }
+
if p[0] != msgKexInit {
return p, nil
}
- t.mu.Lock()
-
firstKex := t.sessionID == nil
- err = t.enterKeyExchangeLocked(p)
- if err != nil {
- // drop connection
- t.conn.Close()
- t.writeError = err
+ kex := pendingKex{
+ done: make(chan error, 1),
+ otherInit: p,
}
+ t.startKex <- &kex
+ err = <-kex.done
if debugHandshake {
log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
}
- // Unblock writers.
- t.sentInitMsg = nil
- t.sentInitPacket = nil
- t.cond.Broadcast()
- t.writtenSinceKex = 0
- t.mu.Unlock()
-
if err != nil {
return nil, err
}
@@ -213,61 +359,16 @@
return successPacket, nil
}
-// keyChangeCategory describes whether a key exchange is the first on a
-// connection, or a subsequent one.
-type keyChangeCategory bool
-
-const (
- firstKeyExchange keyChangeCategory = true
- subsequentKeyExchange keyChangeCategory = false
-)
-
-// sendKexInit sends a key change message, and returns the message
-// that was sent. After initiating the key change, all writes will be
-// blocked until the change is done, and a failed key change will
-// close the underlying transport. This function is safe for
-// concurrent use by multiple goroutines.
-func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) error {
- var err error
-
+// sendKexInit sends a key change message.
+func (t *handshakeTransport) sendKexInit() error {
t.mu.Lock()
- // If this is the initial key change, but we already have a sessionID,
- // then do nothing because the key exchange has already completed
- // asynchronously.
- if !isFirst || t.sessionID == nil {
- _, _, err = t.sendKexInitLocked(isFirst)
- }
- t.mu.Unlock()
- if err != nil {
- return err
- }
- if isFirst {
- if packet, err := t.readPacket(); err != nil {
- return err
- } else if packet[0] != msgNewKeys {
- return unexpectedMessageError(msgNewKeys, packet[0])
- }
- }
- return nil
-}
-
-func (t *handshakeTransport) requestInitialKeyChange() error {
- return t.sendKexInit(firstKeyExchange)
-}
-
-func (t *handshakeTransport) requestKeyChange() error {
- return t.sendKexInit(subsequentKeyExchange)
-}
-
-// sendKexInitLocked sends a key change message. t.mu must be locked
-// while this happens.
-func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
- // kexInits may be sent either in response to the other side,
- // or because our side wants to initiate a key change, so we
- // may have already sent a kexInit. In that case, don't send a
- // second kexInit.
+ defer t.mu.Unlock()
if t.sentInitMsg != nil {
- return t.sentInitMsg, t.sentInitPacket, nil
+ // kexInits may be sent either in response to the other side,
+ // or because our side wants to initiate a key change, so we
+ // may have already sent a kexInit. In that case, don't send a
+ // second kexInit.
+ return nil
}
msg := &kexInitMsg{
@@ -295,53 +396,57 @@
packetCopy := make([]byte, len(packet))
copy(packetCopy, packet)
- if err := t.conn.writePacket(packetCopy); err != nil {
- return nil, nil, err
+ if err := t.pushPacket(packetCopy); err != nil {
+ return err
}
t.sentInitMsg = msg
t.sentInitPacket = packet
- return msg, packet, nil
+
+ return nil
}
func (t *handshakeTransport) writePacket(p []byte) error {
- t.mu.Lock()
- defer t.mu.Unlock()
-
- if t.writtenSinceKex > t.config.RekeyThreshold {
- t.sendKexInitLocked(subsequentKeyExchange)
- }
- for t.sentInitMsg != nil && t.writeError == nil {
- t.cond.Wait()
- }
- if t.writeError != nil {
- return t.writeError
- }
- t.writtenSinceKex += uint64(len(p))
-
switch p[0] {
case msgKexInit:
return errors.New("ssh: only handshakeTransport can send kexInit")
case msgNewKeys:
return errors.New("ssh: only handshakeTransport can send newKeys")
- default:
- return t.conn.writePacket(p)
}
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.writeError != nil {
+ return t.writeError
+ }
+
+ if t.sentInitMsg != nil {
+ // Copy the packet so the writer can reuse the buffer.
+ cp := make([]byte, len(p))
+ copy(cp, p)
+ t.pendingPackets = append(t.pendingPackets, cp)
+ return nil
+ }
+ t.writtenSinceKex += uint64(len(p))
+ if t.writtenSinceKex > t.config.RekeyThreshold {
+ t.requestKeyExchange()
+ }
+
+ if err := t.pushPacket(p); err != nil {
+ t.writeError = err
+ }
+
+ return nil
}
func (t *handshakeTransport) Close() error {
return t.conn.Close()
}
-// enterKeyExchange runs the key exchange. t.mu must be held while running this.
-func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
+func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
if debugHandshake {
log.Printf("%s entered key exchange", t.id())
}
- myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
- if err != nil {
- return err
- }
otherInit := &kexInitMsg{}
if err := Unmarshal(otherInitPacket, otherInit); err != nil {
@@ -352,16 +457,15 @@
clientVersion: t.clientVersion,
serverVersion: t.serverVersion,
clientKexInit: otherInitPacket,
- serverKexInit: myInitPacket,
+ serverKexInit: t.sentInitPacket,
}
clientInit := otherInit
- serverInit := myInit
+ serverInit := t.sentInitMsg
if len(t.hostKeys) == 0 {
- clientInit = myInit
- serverInit = otherInit
+ clientInit, serverInit = serverInit, clientInit
- magics.clientKexInit = myInitPacket
+ magics.clientKexInit = t.sentInitPacket
magics.serverKexInit = otherInitPacket
}
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 4cf5368..530d7d2 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -110,6 +110,13 @@
serverConf.SetDefaults()
server = newServerTransport(trS, v, v, serverConf)
+ if err := server.waitSession(); err != nil {
+ return nil, nil, fmt.Errorf("server.waitSession: %v", err)
+ }
+ if err := client.waitSession(); err != nil {
+ return nil, nil, fmt.Errorf("client.waitSession: %v", err)
+ }
+
return client, server, nil
}
@@ -117,8 +124,9 @@
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
- checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", true)
+
+ checker := &syncChecker{make(chan int, 10)}
+ trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
@@ -126,7 +134,11 @@
defer trC.Close()
defer trS.Close()
+ <-checker.called
+
clientDone := make(chan int, 0)
+ gotHalf := make(chan int, 0)
+
go func() {
defer close(clientDone)
// Client writes a bunch of stuff, and does a key
@@ -138,33 +150,35 @@
t.Fatalf("sendPacket: %v", err)
}
if i == 5 {
+ <-gotHalf
// halfway through, we request a key change.
- err := trC.sendKexInit(subsequentKeyExchange)
- if err != nil {
- t.Fatalf("sendKexInit: %v", err)
- }
+ trC.requestKeyExchange()
+
+ // Wait until we can be sure the key
+ // change has really started before we
+ // write more.
+ <-checker.called
}
}
- trC.Close()
}()
// Server checks that client messages come in cleanly
i := 0
err = nil
- for {
+ for ; i < 10; i++ {
var p []byte
p, err = trS.readPacket()
if err != nil {
break
}
- if p[0] == msgNewKeys {
- continue
+ if i == 5 {
+ gotHalf <- 1
}
+
want := []byte{msgRequestSuccess, byte(i)}
if bytes.Compare(p, want) != 0 {
t.Errorf("message %d: got %q, want %q", i, p, want)
}
- i++
}
<-clientDone
if err != nil && err != io.EOF {
@@ -174,150 +188,58 @@
t.Errorf("received %d messages, want 10.", i)
}
- // If all went well, we registered exactly 1 key change.
- if len(checker.calls) != 1 {
- t.Fatalf("got %d host key checks, want 1", len(checker.calls))
- }
-
- pub := testSigners["ecdsa"].PublicKey()
- want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
- if want != checker.calls[0] {
- t.Errorf("got %q want %q for host key check", checker.calls[0], want)
- }
-
-}
-
-func TestHandshakeError(t *testing.T) {
- checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad", false)
- if err != nil {
- t.Fatalf("handshakePair: %v", err)
- }
- defer trC.Close()
- defer trS.Close()
-
- // send a packet
- packet := []byte{msgRequestSuccess, 42}
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
-
- // Now request a key change.
- err = trC.sendKexInit(subsequentKeyExchange)
- if err != nil {
- t.Errorf("sendKexInit: %v", err)
- }
-
- // the key change will fail, and afterwards we can't write.
- if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
- t.Errorf("writePacket after botched rekey succeeded.")
- }
-
- readback, err := trS.readPacket()
- if err != nil {
- t.Fatalf("server closed too soon: %v", err)
- }
- if bytes.Compare(readback, packet) != 0 {
- t.Errorf("got %q want %q", readback, packet)
- }
- readback, err = trS.readPacket()
- if err == nil {
- t.Errorf("got a message %q after failed key change", readback)
+ close(checker.called)
+ if _, ok := <-checker.called; ok {
+ // If all went well, we registered exactly 2 key changes: one
+ // that establishes the session, and one that we requested
+ // additionally.
+ t.Fatalf("got another host key checks after 2 handshakes")
}
}
func TestForceFirstKex(t *testing.T) {
+ // like handshakePair, but must access the keyingTransport.
checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
+ clientConf := &ClientConfig{HostKeyCallback: checker.Check}
+ a, b, err := netPipe()
if err != nil {
- t.Fatalf("handshakePair: %v", err)
+ t.Fatalf("netPipe: %v", err)
}
- defer trC.Close()
- defer trS.Close()
+ var trC, trS keyingTransport
+ trC = newTransport(a, rand.Reader, true)
+
+ // This is the disallowed packet:
trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
+ // Rest of the setup.
+ trS = newTransport(b, rand.Reader, false)
+ clientConf.SetDefaults()
+
+ v := []byte("version")
+ client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
+
+ serverConf := &ServerConfig{}
+ serverConf.AddHostKey(testSigners["ecdsa"])
+ serverConf.AddHostKey(testSigners["rsa"])
+ serverConf.SetDefaults()
+ server := newServerTransport(trS, v, v, serverConf)
+
+ defer client.Close()
+ defer server.Close()
+
// We setup the initial key exchange, but the remote side
// tries to send serviceRequestMsg in cleartext, which is
// disallowed.
- err = trS.sendKexInit(firstKeyExchange)
- if err == nil {
+ if err := server.waitSession(); err == nil {
t.Errorf("server first kex init should reject unexpected packet")
}
}
-func TestHandshakeTwice(t *testing.T) {
- checker := &testChecker{}
- trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
- if err != nil {
- t.Fatalf("handshakePair: %v", err)
- }
-
- defer trC.Close()
- defer trS.Close()
-
- // Both sides should ask for the first key exchange first.
- err = trS.sendKexInit(firstKeyExchange)
- if err != nil {
- t.Errorf("server sendKexInit: %v", err)
- }
-
- err = trC.sendKexInit(firstKeyExchange)
- if err != nil {
- t.Errorf("client sendKexInit: %v", err)
- }
-
- sent := 0
- // send a packet
- packet := make([]byte, 5)
- packet[0] = msgRequestSuccess
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
- sent++
-
- // Send another packet. Use a fresh one, since writePacket destroys.
- packet = make([]byte, 5)
- packet[0] = msgRequestSuccess
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
- sent++
-
- // 2nd key change.
- err = trC.sendKexInit(subsequentKeyExchange)
- if err != nil {
- t.Errorf("sendKexInit: %v", err)
- }
-
- packet = make([]byte, 5)
- packet[0] = msgRequestSuccess
- if err := trC.writePacket(packet); err != nil {
- t.Errorf("writePacket: %v", err)
- }
- sent++
-
- packet = make([]byte, 5)
- packet[0] = msgRequestSuccess
- for i := 0; i < sent; i++ {
- msg, err := trS.readPacket()
- if err != nil {
- t.Fatalf("server closed too soon: %v", err)
- }
-
- if bytes.Compare(msg, packet) != 0 {
- t.Errorf("packet %d: got %q want %q", i, msg, packet)
- }
- }
- if len(checker.calls) != 2 {
- t.Errorf("got %d key changes, want 2", len(checker.calls))
- }
-}
-
func TestHandshakeAutoRekeyWrite(t *testing.T) {
- checker := &testChecker{}
+ checker := &syncChecker{make(chan int, 10)}
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr", false)
@@ -327,12 +249,19 @@
defer trC.Close()
defer trS.Close()
+ <-checker.called
+
for i := 0; i < 5; i++ {
packet := make([]byte, 251)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
+ if i == 2 {
+ // Make sure the kex is in progress.
+ <-checker.called
+ }
+
}
j := 0
@@ -346,18 +275,14 @@
if j != 5 {
t.Errorf("got %d, want 5 messages", j)
}
-
- if len(checker.calls) != 2 {
- t.Errorf("got %d key changes, wanted 2", len(checker.calls))
- }
}
type syncChecker struct {
called chan int
}
-func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
- t.called <- 1
+func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
+ c.called <- 1
return nil
}
@@ -399,6 +324,7 @@
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
return nil
}
+
func (n *errorKeyingTransport) getSessionID() []byte {
return nil
}
@@ -425,20 +351,32 @@
func TestHandshakeErrorHandlingRead(t *testing.T) {
for i := 0; i < 20; i++ {
- testHandshakeErrorHandlingN(t, i, -1)
+ testHandshakeErrorHandlingN(t, i, -1, false)
}
}
func TestHandshakeErrorHandlingWrite(t *testing.T) {
for i := 0; i < 20; i++ {
- testHandshakeErrorHandlingN(t, -1, i)
+ testHandshakeErrorHandlingN(t, -1, i, false)
+ }
+}
+
+func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ testHandshakeErrorHandlingN(t, i, -1, true)
+ }
+}
+
+func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
+ for i := 0; i < 20; i++ {
+ testHandshakeErrorHandlingN(t, -1, i, true)
}
}
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and
// panic.
-func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
a, b := memPipe()
@@ -451,37 +389,57 @@
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
serverConn.hostKeys = []Signer{key}
go serverConn.readLoop()
+ go serverConn.kexLoop()
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
clientConf.SetDefaults()
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
go clientConn.readLoop()
+ go clientConn.kexLoop()
var wg sync.WaitGroup
- wg.Add(4)
for _, hs := range []packetConn{serverConn, clientConn} {
- go func(c packetConn) {
- for {
- err := c.writePacket(msg)
- if err != nil {
- break
+ if !coupled {
+ wg.Add(2)
+ go func(c packetConn) {
+ for i := 0; ; i++ {
+ str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
+ err := c.writePacket(Marshal(&serviceRequestMsg{str}))
+ if err != nil {
+ break
+ }
}
- }
- wg.Done()
- }(hs)
- go func(c packetConn) {
- for {
- _, err := c.readPacket()
- if err != nil {
- break
+ wg.Done()
+ c.Close()
+ }(hs)
+ go func(c packetConn) {
+ for {
+ _, err := c.readPacket()
+ if err != nil {
+ break
+ }
}
- }
- wg.Done()
- }(hs)
- }
+ wg.Done()
+ }(hs)
+ } else {
+ wg.Add(1)
+ go func(c packetConn) {
+ for {
+ _, err := c.readPacket()
+ if err != nil {
+ break
+ }
+ if err := c.writePacket(msg); err != nil {
+ break
+ }
+ }
+ wg.Done()
+ }(hs)
+ }
+ }
wg.Wait()
}
diff --git a/ssh/server.go b/ssh/server.go
index 9037470..28b109a 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -188,7 +188,7 @@
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
- if err := s.transport.requestInitialKeyChange(); err != nil {
+ if err := s.transport.waitSession(); err != nil {
return nil, err
}
@@ -260,7 +260,7 @@
}
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
- var err error
+ sessionID := s.transport.getSessionID()
var cache pubKeyCache
var perms *Permissions
@@ -385,7 +385,7 @@
if !isAcceptableAlgo(sig.Format) {
break
}
- signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
+ signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
if err := pubKey.Verify(signedData, sig); err != nil {
return nil, err
@@ -421,12 +421,12 @@
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
- if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
+ if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
return nil, err
}
}
- if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
+ if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
return nil, err
}
return perms, nil
diff --git a/ssh/transport.go b/ssh/transport.go
index 7e43a12..fd19932 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -22,7 +22,9 @@
// Encrypt and send a packet of data to the remote peer.
writePacket(packet []byte) error
- // Read a packet from the connection
+ // Read a packet from the connection. The read is blocking,
+ // i.e. if error is nil, then the returned byte slice is
+ // always non-empty.
readPacket() ([]byte, error)
// Close closes the write-side of the connection.