ssh/terminal: fix line endings handling in ReadPassword

Fixes golang/go#16552

Change-Id: I18a9c9b42fe042c4871b3efb3f51bef7cca335d0
Reviewed-on: https://go-review.googlesource.com/25355
Reviewed-by: Adam Langley <alangley@gmail.com>
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/terminal.go b/terminal.go
index 0c5bd56..d35e8b1 100644
--- a/terminal.go
+++ b/terminal.go
@@ -920,3 +920,31 @@
 	}
 	return s.entries[index], true
 }
+
+// readPasswordLine reads from reader until it finds \n or io.EOF.
+// The slice returned does not include the \n.
+// readPasswordLine also ignores any \r it finds.
+func readPasswordLine(reader io.Reader) ([]byte, error) {
+	var buf [1]byte
+	var ret []byte
+
+	for {
+		n, err := reader.Read(buf[:])
+		if err != nil {
+			if err == io.EOF && len(ret) > 0 {
+				return ret, nil
+			}
+			return ret, err
+		}
+		if n > 0 {
+			switch buf[0] {
+			case '\n':
+				return ret, nil
+			case '\r':
+				// remove \r from passwords on Windows
+			default:
+				ret = append(ret, buf[0])
+			}
+		}
+	}
+}
diff --git a/terminal_test.go b/terminal_test.go
index 1d54c4f..901c72a 100644
--- a/terminal_test.go
+++ b/terminal_test.go
@@ -270,6 +270,50 @@
 	}
 }
 
+func TestReadPasswordLineEnd(t *testing.T) {
+	var tests = []struct {
+		input string
+		want  string
+	}{
+		{"\n", ""},
+		{"\r\n", ""},
+		{"test\r\n", "test"},
+		{"testtesttesttes\n", "testtesttesttes"},
+		{"testtesttesttes\r\n", "testtesttesttes"},
+		{"testtesttesttesttest\n", "testtesttesttesttest"},
+		{"testtesttesttesttest\r\n", "testtesttesttesttest"},
+	}
+	for _, test := range tests {
+		buf := new(bytes.Buffer)
+		if _, err := buf.WriteString(test.input); err != nil {
+			t.Fatal(err)
+		}
+
+		have, err := readPasswordLine(buf)
+		if err != nil {
+			t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
+			continue
+		}
+		if string(have) != test.want {
+			t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
+			continue
+		}
+
+		if _, err = buf.WriteString(test.input); err != nil {
+			t.Fatal(err)
+		}
+		have, err = readPasswordLine(buf)
+		if err != nil {
+			t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
+			continue
+		}
+		if string(have) != test.want {
+			t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
+			continue
+		}
+	}
+}
+
 func TestMakeRawState(t *testing.T) {
 	fd := int(os.Stdout.Fd())
 	if !IsTerminal(fd) {
diff --git a/util.go b/util.go
index 747f1b8..d019196 100644
--- a/util.go
+++ b/util.go
@@ -17,7 +17,6 @@
 package terminal // import "golang.org/x/crypto/ssh/terminal"
 
 import (
-	"io"
 	"syscall"
 	"unsafe"
 )
@@ -88,6 +87,13 @@
 	return int(dimensions[1]), int(dimensions[0]), nil
 }
 
+// passwordReader is an io.Reader that reads from a specific file descriptor.
+type passwordReader int
+
+func (r passwordReader) Read(buf []byte) (int, error) {
+	return syscall.Read(int(r), buf)
+}
+
 // ReadPassword reads a line of input from a terminal without local echo.  This
 // is commonly used for inputting passwords and other sensitive data. The slice
 // returned does not include the \n.
@@ -109,27 +115,5 @@
 		syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
 	}()
 
-	var buf [16]byte
-	var ret []byte
-	for {
-		n, err := syscall.Read(fd, buf[:])
-		if err != nil {
-			return nil, err
-		}
-		if n == 0 {
-			if len(ret) == 0 {
-				return nil, io.EOF
-			}
-			break
-		}
-		if buf[n-1] == '\n' {
-			n--
-		}
-		ret = append(ret, buf[:n]...)
-		if n < len(buf) {
-			break
-		}
-	}
-
-	return ret, nil
+	return readPasswordLine(passwordReader(fd))
 }
diff --git a/util_windows.go b/util_windows.go
index ae9fa9e..e0a1f36 100644
--- a/util_windows.go
+++ b/util_windows.go
@@ -17,7 +17,6 @@
 package terminal
 
 import (
-	"io"
 	"syscall"
 	"unsafe"
 )
@@ -123,6 +122,13 @@
 	return int(info.size.x), int(info.size.y), nil
 }
 
+// passwordReader is an io.Reader that reads from a specific Windows HANDLE.
+type passwordReader int
+
+func (r passwordReader) Read(buf []byte) (int, error) {
+	return syscall.Read(syscall.Handle(r), buf)
+}
+
 // ReadPassword reads a line of input from a terminal without local echo.  This
 // is commonly used for inputting passwords and other sensitive data. The slice
 // returned does not include the \n.
@@ -145,30 +151,5 @@
 		syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
 	}()
 
-	var buf [16]byte
-	var ret []byte
-	for {
-		n, err := syscall.Read(syscall.Handle(fd), buf[:])
-		if err != nil {
-			return nil, err
-		}
-		if n == 0 {
-			if len(ret) == 0 {
-				return nil, io.EOF
-			}
-			break
-		}
-		if buf[n-1] == '\n' {
-			n--
-		}
-		if n > 0 && buf[n-1] == '\r' {
-			n--
-		}
-		ret = append(ret, buf[:n]...)
-		if n < len(buf) {
-			break
-		}
-	}
-
-	return ret, nil
+	return readPasswordLine(passwordReader(fd))
 }