ssh: fix protocol version exchange (for multi-line)

Fixes golang/go#23194

During SSH Protocol Version Exchange, a client may send metadata lines
prior to sending the SSH version string. To conform to the RFC, all SSH
implementations must support this (minimally, clients can ignore the
metadata lines).

For example, this is valid:
some-metadata
SSH-2.0-OpenSSH

The current Go implementation takes the first line it sees as
the version string (in this case, some-metadata). Then, it uses
the next line (SSH-2.0-OpenSSH) as part of key exchange, which
is guaranteed to fail.

Unfortunately, this SSH feature is used by some vendors and is part
of the official RFC: https://tools.ietf.org/html/rfc4253#section-4.2

Change-Id: I7be61700a07756353875bf43aad09a580ba533ff
Reviewed-on: https://go-review.googlesource.com/86675
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/client_test.go b/ssh/client_test.go
index ef95069..81f9599 100644
--- a/ssh/client_test.go
+++ b/ssh/client_test.go
@@ -5,41 +5,77 @@
 package ssh
 
 import (
-	"net"
 	"strings"
 	"testing"
 )
 
-func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
-	clientConn, serverConn := net.Pipe()
-	defer clientConn.Close()
-	receivedVersion := make(chan string, 1)
-	config.HostKeyCallback = InsecureIgnoreHostKey()
-	go func() {
-		version, err := readVersion(serverConn)
-		if err != nil {
-			receivedVersion <- ""
-		} else {
-			receivedVersion <- string(version)
-		}
-		serverConn.Close()
-	}()
-	NewClientConn(clientConn, "", config)
-	actual := <-receivedVersion
-	if actual != expected {
-		t.Fatalf("got %s; want %s", actual, expected)
+func TestClientVersion(t *testing.T) {
+	for _, tt := range []struct {
+		name      string
+		version   string
+		multiLine string
+		wantErr   bool
+	}{
+		{
+			name:    "default version",
+			version: packageVersion,
+		},
+		{
+			name:    "custom version",
+			version: "SSH-2.0-CustomClientVersionString",
+		},
+		{
+			name:      "good multi line version",
+			version:   packageVersion,
+			multiLine: strings.Repeat("ignored\r\n", 20),
+		},
+		{
+			name:      "bad multi line version",
+			version:   packageVersion,
+			multiLine: "bad multi line version",
+			wantErr:   true,
+		},
+		{
+			name:      "long multi line version",
+			version:   packageVersion,
+			multiLine: strings.Repeat("long multi line version\r\n", 50)[:256],
+			wantErr:   true,
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			c1, c2, err := netPipe()
+			if err != nil {
+				t.Fatalf("netPipe: %v", err)
+			}
+			defer c1.Close()
+			defer c2.Close()
+			go func() {
+				if tt.multiLine != "" {
+					c1.Write([]byte(tt.multiLine))
+				}
+				NewClientConn(c1, "", &ClientConfig{
+					ClientVersion:   tt.version,
+					HostKeyCallback: InsecureIgnoreHostKey(),
+				})
+				c1.Close()
+			}()
+			conf := &ServerConfig{NoClientAuth: true}
+			conf.AddHostKey(testSigners["rsa"])
+			conn, _, _, err := NewServerConn(c2, conf)
+			if err == nil == tt.wantErr {
+				t.Fatalf("got err %v; wantErr %t", err, tt.wantErr)
+			}
+			if tt.wantErr {
+				// Don't verify the version on an expected error.
+				return
+			}
+			if got := string(conn.ClientVersion()); got != tt.version {
+				t.Fatalf("got %q; want %q", got, tt.version)
+			}
+		})
 	}
 }
 
-func TestCustomClientVersion(t *testing.T) {
-	version := "Test-Client-Version-0.0"
-	testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
-}
-
-func TestDefaultClientVersion(t *testing.T) {
-	testClientVersion(t, &ClientConfig{}, packageVersion)
-}
-
 func TestHostKeyCheck(t *testing.T) {
 	for _, tt := range []struct {
 		name      string
diff --git a/ssh/transport.go b/ssh/transport.go
index 01150eb..82da0d7 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -6,6 +6,7 @@
 
 import (
 	"bufio"
+	"bytes"
 	"errors"
 	"io"
 	"log"
@@ -342,7 +343,7 @@
 	var ok bool
 	var buf [1]byte
 
-	for len(versionString) < maxVersionStringBytes {
+	for length := 0; length < maxVersionStringBytes; length++ {
 		_, err := io.ReadFull(r, buf[:])
 		if err != nil {
 			return nil, err
@@ -350,6 +351,13 @@
 		// The RFC says that the version should be terminated with \r\n
 		// but several SSH servers actually only send a \n.
 		if buf[0] == '\n' {
+			if !bytes.HasPrefix(versionString, []byte("SSH-")) {
+				// RFC 4253 says we need to ignore all version string lines
+				// except the one containing the SSH version (provided that
+				// all the lines do not exceed 255 bytes in total).
+				versionString = versionString[:0]
+				continue
+			}
 			ok = true
 			break
 		}
diff --git a/ssh/transport_test.go b/ssh/transport_test.go
index 92d83ab..8445e1e 100644
--- a/ssh/transport_test.go
+++ b/ssh/transport_test.go
@@ -13,11 +13,13 @@
 )
 
 func TestReadVersion(t *testing.T) {
-	longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n"
 	cases := map[string]string{
 		"SSH-2.0-bla\r\n":    "SSH-2.0-bla",
 		"SSH-2.0-bla\n":      "SSH-2.0-bla",
-		longversion + "\r\n": longversion,
+		multiLineVersion:     "SSH-2.0-bla",
+		longVersion + "\r\n": longVersion,
 	}
 
 	for in, want := range cases {
@@ -33,9 +35,11 @@
 }
 
 func TestReadVersionError(t *testing.T) {
-	longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
+	multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n"
 	cases := []string{
-		longversion + "too-long\r\n",
+		longVersion + "too-long\r\n",
+		multiLineVersion,
 	}
 	for _, in := range cases {
 		if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
@@ -60,7 +64,7 @@
 func TestExchangeVersions(t *testing.T) {
 	cases := []string{
 		"not\x000allowed",
-		"not allowed\n",
+		"not allowed\x01\r\n",
 	}
 	for _, c := range cases {
 		buf := bytes.NewBufferString("SSH-2.0-bla\r\n")