ssh: return session ID in ConnMeta.SessionID.
SessionID() returned nil previously.
Fixes #9761.
Change-Id: I53d2b347571d21eab2d913c2228e85997a84f757
Reviewed-on: https://go-review.googlesource.com/3872
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/client.go b/ssh/client.go
index e607610..72bd27f 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -105,6 +105,10 @@
} else if packet[0] != msgNewKeys {
return unexpectedMessageError(msgNewKeys, packet[0])
}
+
+ // We just did the key change, so the session ID is established.
+ c.sessionID = c.transport.getSessionID()
+
return c.clientAuthenticate(config)
}
diff --git a/ssh/server.go b/ssh/server.go
index ee2eeab..52aee11 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -192,6 +192,9 @@
return nil, unexpectedMessageError(msgNewKeys, packet[0])
}
+ // We just did the key change, so the session ID is established.
+ s.sessionID = s.transport.getSessionID()
+
var packet []byte
if packet, err = s.transport.readPacket(); err != nil {
return nil, err
diff --git a/ssh/session_test.go b/ssh/session_test.go
index fce9868..88e66bf 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -626,3 +626,55 @@
t.Errorf("handler write error: %v", err)
}
}
+
+func TestSessionID(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ serverID := make(chan []byte, 1)
+ clientID := make(chan []byte, 1)
+
+ serverConf := &ServerConfig{
+ NoClientAuth: true,
+ }
+ serverConf.AddHostKey(testSigners["ecdsa"])
+ clientConf := &ClientConfig{
+ User: "user",
+ }
+
+ go func() {
+ conn, chans, reqs, err := NewServerConn(c1, serverConf)
+ if err != nil {
+ t.Fatalf("server handshake: %v", err)
+ }
+ serverID <- conn.SessionID()
+ go DiscardRequests(reqs)
+ for ch := range chans {
+ ch.Reject(Prohibited, "")
+ }
+ }()
+
+ go func() {
+ conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
+ if err != nil {
+ t.Fatalf("client handshake: %v", err)
+ }
+ clientID <- conn.SessionID()
+ go DiscardRequests(reqs)
+ for ch := range chans {
+ ch.Reject(Prohibited, "")
+ }
+ }()
+
+ s := <-serverID
+ c := <-clientID
+ if bytes.Compare(s, c) != 0 {
+ t.Errorf("server session ID (%x) != client session ID (%x)", s, c)
+ } else if len(s) == 0 {
+ t.Errorf("client and server SessionID were empty.")
+ }
+}
diff --git a/ssh/transport.go b/ssh/transport.go
index 4f68b04..0f19478 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -44,13 +44,13 @@
sessionID []byte
}
+// getSessionID returns the ID of the SSH connection. The return value
+// should not be modified.
func (t *transport) getSessionID() []byte {
if t.sessionID == nil {
panic("session ID not set yet")
}
- s := make([]byte, len(t.sessionID))
- copy(s, t.sessionID)
- return s
+ return t.sessionID
}
// packetCipher represents a combination of SSH encryption/MAC