ssh/terminal: support home, end, up and down keys.
R=golang-dev, dave
CC=golang-dev
https://golang.org/cl/9777043
diff --git a/ssh/terminal/terminal.go b/ssh/terminal/terminal.go
index c1ed0c0..d956b51 100644
--- a/ssh/terminal/terminal.go
+++ b/ssh/terminal/terminal.go
@@ -75,6 +75,17 @@
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
+
+ // history contains previously entered commands so that they can be
+ // accessed with the up and down keys.
+ history stRingBuffer
+ // historyIndex stores the currently accessed history entry, where zero
+ // means the immediately previous entry.
+ historyIndex int
+ // When navigating up and down the history it's possible to return to
+ // the incomplete, initial line. That value is stored in
+ // historyPending.
+ historyPending string
}
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
@@ -83,12 +94,13 @@
// "> ").
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
return &Terminal{
- Escape: &vt100EscapeCodes,
- c: c,
- prompt: prompt,
- termWidth: 80,
- termHeight: 24,
- echo: true,
+ Escape: &vt100EscapeCodes,
+ c: c,
+ prompt: prompt,
+ termWidth: 80,
+ termHeight: 24,
+ echo: true,
+ historyIndex: -1,
}
}
@@ -104,6 +116,8 @@
keyRight
keyAltLeft
keyAltRight
+ keyHome
+ keyEnd
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
@@ -130,6 +144,15 @@
}
}
+ if len(b) >= 3 && b[0] == keyEscape && b[1] == 'O' {
+ switch b[2] {
+ case 'H':
+ return keyHome, b[3:]
+ case 'F':
+ return keyEnd, b[3:]
+ }
+ }
+
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
@@ -238,6 +261,19 @@
const maxLineLength = 4096
+func (t *Terminal) setLine(newLine []byte, newPos int) {
+ if t.echo {
+ t.moveCursorToPos(0)
+ t.writeLine(newLine)
+ for i := len(newLine); i < len(t.line); i++ {
+ t.writeLine(space)
+ }
+ t.moveCursorToPos(newPos)
+ }
+ t.line = newLine
+ t.pos = newPos
+}
+
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (t *Terminal) handleKey(key int) (line string, ok bool) {
@@ -303,6 +339,42 @@
}
t.pos++
t.moveCursorToPos(t.pos)
+ case keyHome:
+ if t.pos == 0 {
+ return
+ }
+ t.pos = 0
+ t.moveCursorToPos(t.pos)
+ case keyEnd:
+ if t.pos == len(t.line) {
+ return
+ }
+ t.pos = len(t.line)
+ t.moveCursorToPos(t.pos)
+ case keyUp:
+ entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1)
+ if !ok {
+ return "", false
+ }
+ if t.historyIndex == -1 {
+ t.historyPending = string(t.line)
+ }
+ t.historyIndex++
+ t.setLine([]byte(entry), len(entry))
+ case keyDown:
+ switch t.historyIndex {
+ case -1:
+ return
+ case 0:
+ t.setLine([]byte(t.historyPending), len(t.historyPending))
+ t.historyIndex--
+ default:
+ entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1)
+ if ok {
+ t.historyIndex--
+ t.setLine([]byte(entry), len(entry))
+ }
+ }
case keyEnter:
t.moveCursorToPos(len(t.line))
t.queue([]byte("\r\n"))
@@ -320,16 +392,7 @@
t.lock.Lock()
if newLine != nil {
- if t.echo {
- t.moveCursorToPos(0)
- t.writeLine(newLine)
- for i := len(newLine); i < len(t.line); i++ {
- t.writeLine(space)
- }
- t.moveCursorToPos(newPos)
- }
- t.line = newLine
- t.pos = newPos
+ t.setLine(newLine, newPos)
return
}
}
@@ -483,6 +546,8 @@
t.c.Write(t.outBuf)
t.outBuf = t.outBuf[:0]
if lineOk {
+ t.historyIndex = -1
+ t.history.Add(line)
return
}
@@ -501,7 +566,6 @@
t.remainder = t.inBuf[:n+len(t.remainder)]
}
- panic("unreachable")
}
// SetPrompt sets the prompt to be used when reading subsequent lines.
@@ -518,3 +582,43 @@
t.termWidth, t.termHeight = width, height
}
+
+// stRingBuffer is a ring buffer of strings.
+type stRingBuffer struct {
+ // entries contains max elements.
+ entries []string
+ max int
+ // head contains the index of the element most recently added to the ring.
+ head int
+ // size contains the number of elements in the ring.
+ size int
+}
+
+func (s *stRingBuffer) Add(a string) {
+ if s.entries == nil {
+ const defaultNumEntries = 100
+ s.entries = make([]string, defaultNumEntries)
+ s.max = defaultNumEntries
+ }
+
+ s.head = (s.head + 1) % s.max
+ s.entries[s.head] = a
+ if s.size < s.max {
+ s.size++
+ }
+}
+
+// NthPreviousEntry returns the value passed to the nth previous call to Add.
+// If n is zero then the immediately prior value is returned, if one, then the
+// next most recent, and so on. If such an element doesn't exist then ok is
+// false.
+func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) {
+ if n >= s.size {
+ return "", false
+ }
+ index := s.head - n
+ if index < 0 {
+ index += s.max
+ }
+ return s.entries[index], true
+}
diff --git a/ssh/terminal/terminal_test.go b/ssh/terminal/terminal_test.go
index a219721..ffcda79 100644
--- a/ssh/terminal/terminal_test.go
+++ b/ssh/terminal/terminal_test.go
@@ -52,39 +52,54 @@
}
var keyPressTests = []struct {
- in string
- line string
- err error
+ in string
+ line string
+ err error
+ throwAwayLines int
}{
{
- "",
- "",
- io.EOF,
+ err: io.EOF,
},
{
- "\r",
- "",
- nil,
+ in: "\r",
+ line: "",
},
{
- "foo\r",
- "foo",
- nil,
+ in: "foo\r",
+ line: "foo",
},
{
- "a\x1b[Cb\r", // right
- "ab",
- nil,
+ in: "a\x1b[Cb\r", // right
+ line: "ab",
},
{
- "a\x1b[Db\r", // left
- "ba",
- nil,
+ in: "a\x1b[Db\r", // left
+ line: "ba",
},
{
- "a\177b\r", // backspace
- "b",
- nil,
+ in: "a\177b\r", // backspace
+ line: "b",
+ },
+ {
+ in: "\x1b[A\r", // up
+ },
+ {
+ in: "\x1b[B\r", // down
+ },
+ {
+ in: "line\x1b[A\x1b[B\r", // up then down
+ line: "line",
+ },
+ {
+ in: "line1\rline2\x1b[A\r", // recall previous line.
+ line: "line1",
+ throwAwayLines: 1,
+ },
+ {
+ // recall two previous lines and append.
+ in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r",
+ line: "line1xxx",
+ throwAwayLines: 2,
},
}
@@ -96,6 +111,12 @@
bytesPerRead: j,
}
ss := NewTerminal(c, "> ")
+ for k := 0; k < test.throwAwayLines; k++ {
+ _, err := ss.ReadLine()
+ if err != nil {
+ t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err)
+ }
+ }
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)