ssh/terminal: fix line endings handling in ReadPassword Fixes golang/go#16552 Change-Id: I18a9c9b42fe042c4871b3efb3f51bef7cca335d0 Reviewed-on: https://go-review.googlesource.com/25355 Reviewed-by: Adam Langley <alangley@gmail.com> Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/terminal.go b/terminal.go index 0c5bd56..d35e8b1 100644 --- a/terminal.go +++ b/terminal.go
@@ -920,3 +920,31 @@ } return s.entries[index], true } + +// readPasswordLine reads from reader until it finds \n or io.EOF. +// The slice returned does not include the \n. +// readPasswordLine also ignores any \r it finds. +func readPasswordLine(reader io.Reader) ([]byte, error) { + var buf [1]byte + var ret []byte + + for { + n, err := reader.Read(buf[:]) + if err != nil { + if err == io.EOF && len(ret) > 0 { + return ret, nil + } + return ret, err + } + if n > 0 { + switch buf[0] { + case '\n': + return ret, nil + case '\r': + // remove \r from passwords on Windows + default: + ret = append(ret, buf[0]) + } + } + } +}
diff --git a/terminal_test.go b/terminal_test.go index 1d54c4f..901c72a 100644 --- a/terminal_test.go +++ b/terminal_test.go
@@ -270,6 +270,50 @@ } } +func TestReadPasswordLineEnd(t *testing.T) { + var tests = []struct { + input string + want string + }{ + {"\n", ""}, + {"\r\n", ""}, + {"test\r\n", "test"}, + {"testtesttesttes\n", "testtesttesttes"}, + {"testtesttesttes\r\n", "testtesttesttes"}, + {"testtesttesttesttest\n", "testtesttesttesttest"}, + {"testtesttesttesttest\r\n", "testtesttesttesttest"}, + } + for _, test := range tests { + buf := new(bytes.Buffer) + if _, err := buf.WriteString(test.input); err != nil { + t.Fatal(err) + } + + have, err := readPasswordLine(buf) + if err != nil { + t.Errorf("readPasswordLine(%q) failed: %v", test.input, err) + continue + } + if string(have) != test.want { + t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want) + continue + } + + if _, err = buf.WriteString(test.input); err != nil { + t.Fatal(err) + } + have, err = readPasswordLine(buf) + if err != nil { + t.Errorf("readPasswordLine(%q) failed: %v", test.input, err) + continue + } + if string(have) != test.want { + t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want) + continue + } + } +} + func TestMakeRawState(t *testing.T) { fd := int(os.Stdout.Fd()) if !IsTerminal(fd) {
diff --git a/util.go b/util.go index 747f1b8..d019196 100644 --- a/util.go +++ b/util.go
@@ -17,7 +17,6 @@ package terminal // import "golang.org/x/crypto/ssh/terminal" import ( - "io" "syscall" "unsafe" ) @@ -88,6 +87,13 @@ return int(dimensions[1]), int(dimensions[0]), nil } +// passwordReader is an io.Reader that reads from a specific file descriptor. +type passwordReader int + +func (r passwordReader) Read(buf []byte) (int, error) { + return syscall.Read(int(r), buf) +} + // ReadPassword reads a line of input from a terminal without local echo. This // is commonly used for inputting passwords and other sensitive data. The slice // returned does not include the \n. @@ -109,27 +115,5 @@ syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) }() - var buf [16]byte - var ret []byte - for { - n, err := syscall.Read(fd, buf[:]) - if err != nil { - return nil, err - } - if n == 0 { - if len(ret) == 0 { - return nil, io.EOF - } - break - } - if buf[n-1] == '\n' { - n-- - } - ret = append(ret, buf[:n]...) - if n < len(buf) { - break - } - } - - return ret, nil + return readPasswordLine(passwordReader(fd)) }
diff --git a/util_windows.go b/util_windows.go index ae9fa9e..e0a1f36 100644 --- a/util_windows.go +++ b/util_windows.go
@@ -17,7 +17,6 @@ package terminal import ( - "io" "syscall" "unsafe" ) @@ -123,6 +122,13 @@ return int(info.size.x), int(info.size.y), nil } +// passwordReader is an io.Reader that reads from a specific Windows HANDLE. +type passwordReader int + +func (r passwordReader) Read(buf []byte) (int, error) { + return syscall.Read(syscall.Handle(r), buf) +} + // ReadPassword reads a line of input from a terminal without local echo. This // is commonly used for inputting passwords and other sensitive data. The slice // returned does not include the \n. @@ -145,30 +151,5 @@ syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) }() - var buf [16]byte - var ret []byte - for { - n, err := syscall.Read(syscall.Handle(fd), buf[:]) - if err != nil { - return nil, err - } - if n == 0 { - if len(ret) == 0 { - return nil, io.EOF - } - break - } - if buf[n-1] == '\n' { - n-- - } - if n > 0 && buf[n-1] == '\r' { - n-- - } - ret = append(ret, buf[:n]...) - if n < len(buf) { - break - } - } - - return ret, nil + return readPasswordLine(passwordReader(fd)) }