go.crypto/ssh: introduce a circular buffer for chanReader
R=agl, gustav.paul, kardianos
CC=golang-dev
https://golang.org/cl/6207051
diff --git a/ssh/buffer.go b/ssh/buffer.go
new file mode 100644
index 0000000..b089089
--- /dev/null
+++ b/ssh/buffer.go
@@ -0,0 +1,96 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "io"
+ "sync"
+)
+
+// buffer provides a linked list buffer for data exchange
+// between producer and consumer. Theoretically the buffer is
+// of unlimited capacity as it does no allocation of its own.
+type buffer struct {
+ // protects concurrent access to head, tail and closed
+ *sync.Cond
+
+ head *element // the buffer that will be read first
+ tail *element // the buffer that will be read last
+
+ closed bool
+}
+
+// An element represents a single link in a linked list.
+type element struct {
+ buf []byte
+ next *element
+}
+
+// newBuffer returns an empty buffer that is not closed.
+func newBuffer() *buffer {
+ e := new(element)
+ b := &buffer{
+ Cond: newCond(),
+ head: e,
+ tail: e,
+ }
+ return b
+}
+
+// write makes buf available for Read to receive.
+// buf must not be modified after the call to write.
+func (b *buffer) write(buf []byte) {
+ b.Cond.L.Lock()
+ defer b.Cond.L.Unlock()
+ e := &element{buf: buf}
+ b.tail.next = e
+ b.tail = e
+ b.Cond.Signal()
+}
+
+// eof closes the buffer. Reads from the buffer once all
+// the data has been consumed will receive os.EOF.
+func (b *buffer) eof() error {
+ b.Cond.L.Lock()
+ defer b.Cond.L.Unlock()
+ b.closed = true
+ b.Cond.Signal()
+ return nil
+}
+
+// Read reads data from the internal buffer in buf.
+// Reads will block if no data is available, or until
+// the buffer is closed.
+func (b *buffer) Read(buf []byte) (n int, err error) {
+ b.Cond.L.Lock()
+ defer b.Cond.L.Unlock()
+ for len(buf) > 0 {
+ // if there is data in b.head, copy it
+ if len(b.head.buf) > 0 {
+ r := copy(buf, b.head.buf)
+ buf, b.head.buf = buf[r:], b.head.buf[r:]
+ n += r
+ continue
+ }
+ // if there is a next buffer, make it the head
+ if len(b.head.buf) == 0 && b.head != b.tail {
+ b.head = b.head.next
+ continue
+ }
+ // if at least one byte has been copied, return
+ if n > 0 {
+ break
+ }
+ // if nothing was read, and there is nothing outstanding
+ // check to see if the buffer is closed.
+ if b.closed {
+ err = io.EOF
+ break
+ }
+ // out of buffers, wait for producer
+ b.Cond.Wait()
+ }
+ return
+}
diff --git a/ssh/buffer_test.go b/ssh/buffer_test.go
new file mode 100644
index 0000000..135c4ae
--- /dev/null
+++ b/ssh/buffer_test.go
@@ -0,0 +1,87 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "io"
+ "testing"
+)
+
+var BYTES = []byte("abcdefghijklmnopqrstuvwxyz")
+
+func TestBufferReadwrite(t *testing.T) {
+ b := newBuffer()
+ b.write(BYTES[:10])
+ r, _ := b.Read(make([]byte, 10))
+ if r != 10 {
+ t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
+ }
+
+ b = newBuffer()
+ b.write(BYTES[:5])
+ r, _ = b.Read(make([]byte, 10))
+ if r != 5 {
+ t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
+ }
+
+ b = newBuffer()
+ b.write(BYTES[:10])
+ r, _ = b.Read(make([]byte, 5))
+ if r != 5 {
+ t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
+ }
+
+ b = newBuffer()
+ b.write(BYTES[:5])
+ b.write(BYTES[5:15])
+ r, _ = b.Read(make([]byte, 10))
+ r2, _ := b.Read(make([]byte, 10))
+ if r != 10 || r2 != 5 || 15 != r+r2 {
+ t.Fatal("Expected written == read == 15")
+ }
+}
+
+func TestBufferClose(t *testing.T) {
+ b := newBuffer()
+ b.write(BYTES[:10])
+ b.eof()
+ _, err := b.Read(make([]byte, 5))
+ if err != nil {
+ t.Fatal("expected read of 5 to not return EOF")
+ }
+ b = newBuffer()
+ b.write(BYTES[:10])
+ b.eof()
+ r, err := b.Read(make([]byte, 5))
+ r2, err2 := b.Read(make([]byte, 10))
+ if r != 5 || r2 != 5 || err != nil || err2 != nil {
+ t.Fatal("expected reads of 5 and 5")
+ }
+
+ b = newBuffer()
+ b.write(BYTES[:10])
+ b.eof()
+ r, err = b.Read(make([]byte, 5))
+ r2, err2 = b.Read(make([]byte, 10))
+ r3, err3 := b.Read(make([]byte, 10))
+ if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF {
+ t.Fatal("expected reads of 5 and 5 and 0, with EOF")
+ }
+
+ b = newBuffer()
+ b.write(make([]byte, 5))
+ b.write(make([]byte, 10))
+ b.eof()
+ r, err = b.Read(make([]byte, 9))
+ r2, err2 = b.Read(make([]byte, 3))
+ r3, err3 = b.Read(make([]byte, 3))
+ r4, err4 := b.Read(make([]byte, 10))
+ if err != nil || err2 != nil || err3 != nil || err4 != io.EOF {
+ t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4)
+ }
+ if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 {
+ t.Fatal("Expected written == read == 15", r, r2, r3, r4)
+ }
+}
diff --git a/ssh/client.go b/ssh/client.go
index 1caf678..54f68df 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -224,7 +224,7 @@
if length != uint32(len(packet)) {
return
}
- c.getChan(remoteId).stdout.handleData(packet)
+ c.getChan(remoteId).stdout.write(packet)
case msgChannelExtendedData:
if len(packet) < 13 {
// malformed data packet
@@ -242,7 +242,7 @@
// for stderr on interactive sessions. Other data types are
// silently discarded.
if datatype == 1 {
- c.getChan(remoteId).stderr.handleData(packet)
+ c.getChan(remoteId).stderr.write(packet)
}
default:
switch msg := decode(packet).(type) {
@@ -448,12 +448,12 @@
channel: &c.channel,
}
c.stdout = &chanReader{
- data: make(chan []byte, 16),
channel: &c.channel,
+ buffer: newBuffer(),
}
c.stderr = &chanReader{
- data: make(chan []byte, 16),
channel: &c.channel,
+ buffer: newBuffer(),
}
return c
}
@@ -579,44 +579,18 @@
// A chanReader represents stdout or stderr of a remote process.
type chanReader struct {
- // TODO(dfc) a fixed size channel may not be the right data structure.
- // If writes to this channel block, they will block mainLoop, making
- // it unable to receive new messages from the remote side.
- data chan []byte // receives data from remote
- dataClosed bool // protects data from being closed twice
- *channel // the channel backing this reader
- buf []byte
-}
-
-// eof signals to the consumer that there is no more data to be received.
-func (r *chanReader) eof() {
- if !r.dataClosed {
- r.dataClosed = true
- close(r.data)
- }
-}
-
-// handleData sends buf to the reader's consumer. If r.data is closed
-// the data will be silently discarded
-func (r *chanReader) handleData(buf []byte) {
- if !r.dataClosed {
- r.data <- buf
- }
+ *channel // the channel backing this reader
+ *buffer
}
// Read reads data from the remote process's stdout or stderr.
-func (r *chanReader) Read(data []byte) (int, error) {
- var ok bool
- for {
- if len(r.buf) > 0 {
- n := copy(data, r.buf)
- r.buf = r.buf[n:]
- return n, r.sendWindowAdj(n)
+func (r *chanReader) Read(buf []byte) (int, error) {
+ n, err := r.buffer.Read(buf)
+ if err != nil {
+ if err == io.EOF {
+ return n, err
}
- r.buf, ok = <-r.data
- if !ok {
- return 0, io.EOF
- }
+ return 0, err
}
- panic("unreachable")
+ return n, r.sendWindowAdj(n)
}