ssh: defer channel window adjustment

Sending a window adjustment after every read is unnecessarily chatty,
especially with a series of small reads like with TTY interactions.

Copy OpenSSH's logic for deferring these, which seemingly hasn't changed
since 2007. Note that since channelWindowSize and c.maxIncomingPayload
are currently constants here, the two checks could be combined into a
single check for c.myWindow < 2 MiB - 96 KiB (with the current values
of the constants).

Fixes golang/go#57424.

Change-Id: Ifcef5be76fcc3f0b1a6dc396096bed9c50d64f21
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/459915
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Run-TryBot: Nicola Murino <nicola.murino@gmail.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Commit-Queue: Nicola Murino <nicola.murino@gmail.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/ssh/channel.go b/ssh/channel.go
index c0834c0..cc0bb7a 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -187,9 +187,11 @@
 	pending    *buffer
 	extPending *buffer
 
-	// windowMu protects myWindow, the flow-control window.
-	windowMu sync.Mutex
-	myWindow uint32
+	// windowMu protects myWindow, the flow-control window, and myConsumed,
+	// the number of bytes consumed since we last increased myWindow
+	windowMu   sync.Mutex
+	myWindow   uint32
+	myConsumed uint32
 
 	// writeMu serializes calls to mux.conn.writePacket() and
 	// protects sentClose and packetPool. This mutex must be
@@ -332,14 +334,24 @@
 	return nil
 }
 
-func (c *channel) adjustWindow(n uint32) error {
+func (c *channel) adjustWindow(adj uint32) error {
 	c.windowMu.Lock()
-	// Since myWindow is managed on our side, and can never exceed
-	// the initial window setting, we don't worry about overflow.
-	c.myWindow += uint32(n)
+	// Since myConsumed and myWindow are managed on our side, and can never
+	// exceed the initial window setting, we don't worry about overflow.
+	c.myConsumed += adj
+	var sendAdj uint32
+	if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
+		(c.myWindow < channelWindowSize/2) {
+		sendAdj = c.myConsumed
+		c.myConsumed = 0
+		c.myWindow += sendAdj
+	}
 	c.windowMu.Unlock()
+	if sendAdj == 0 {
+		return nil
+	}
 	return c.sendMessage(windowAdjustMsg{
-		AdditionalBytes: uint32(n),
+		AdditionalBytes: sendAdj,
 	})
 }
 
diff --git a/ssh/mempipe_test.go b/ssh/mempipe_test.go
index 8697cd6..f27339c 100644
--- a/ssh/mempipe_test.go
+++ b/ssh/mempipe_test.go
@@ -13,9 +13,10 @@
 // An in-memory packetConn. It is safe to call Close and writePacket
 // from different goroutines.
 type memTransport struct {
-	eof     bool
-	pending [][]byte
-	write   *memTransport
+	eof        bool
+	pending    [][]byte
+	write      *memTransport
+	writeCount uint64
 	sync.Mutex
 	*sync.Cond
 }
@@ -63,9 +64,16 @@
 	copy(c, p)
 	t.write.pending = append(t.write.pending, c)
 	t.write.Cond.Signal()
+	t.writeCount++
 	return nil
 }
 
+func (t *memTransport) getWriteCount() uint64 {
+	t.write.Lock()
+	defer t.write.Unlock()
+	return t.writeCount
+}
+
 func memPipe() (a, b packetConn) {
 	t1 := memTransport{}
 	t2 := memTransport{}
@@ -81,6 +89,9 @@
 	if err := a.writePacket([]byte{42}); err != nil {
 		t.Fatalf("writePacket: %v", err)
 	}
+	if wc := a.(*memTransport).getWriteCount(); wc != 1 {
+		t.Fatalf("got %v, want 1", wc)
+	}
 	if err := a.Close(); err != nil {
 		t.Fatal("Close: ", err)
 	}
@@ -95,6 +106,9 @@
 	if err != io.EOF {
 		t.Fatalf("got %v, %v, want EOF", p, err)
 	}
+	if wc := b.(*memTransport).getWriteCount(); wc != 0 {
+		t.Fatalf("got %v, want 0", wc)
+	}
 }
 
 func TestDoubleClose(t *testing.T) {
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index eae637d..21f0ac3 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -182,6 +182,40 @@
 	}
 }
 
+func TestMuxChannelReadUnblock(t *testing.T) {
+	reader, writer, mux := channelPair(t)
+	defer reader.Close()
+	defer writer.Close()
+	defer mux.Close()
+
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+			t.Errorf("could not fill window: %v", err)
+		}
+		if _, err := writer.Write(make([]byte, 1)); err != nil {
+			t.Errorf("Write: %v", err)
+		}
+		writer.Close()
+	}()
+
+	writer.remoteWin.waitWriterBlocked()
+
+	buf := make([]byte, 32768)
+	for {
+		_, err := reader.Read(buf)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Read: %v", err)
+		}
+	}
+}
+
 func TestMuxChannelCloseWriteUnblock(t *testing.T) {
 	reader, writer, mux := channelPair(t)
 	defer reader.Close()
@@ -754,6 +788,43 @@
 	}
 }
 
+func TestMuxChannelWindowDeferredUpdates(t *testing.T) {
+	s, c, mux := channelPair(t)
+	cTransport := mux.conn.(*memTransport)
+	defer s.Close()
+	defer c.Close()
+	defer mux.Close()
+
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+
+	data := make([]byte, 1024)
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		_, err := s.Write(data)
+		if err != nil {
+			t.Errorf("Write: %v", err)
+			return
+		}
+	}()
+	cWritesInit := cTransport.getWriteCount()
+	buf := make([]byte, 1)
+	for i := 0; i < len(data); i++ {
+		n, err := c.Read(buf)
+		if n != len(buf) || err != nil {
+			t.Fatalf("Read: %v, %v", n, err)
+		}
+	}
+	cWrites := cTransport.getWriteCount() - cWritesInit
+	// reading 1 KiB should not cause any window updates to be sent, but allow
+	// for some unexpected writes
+	if cWrites > 30 {
+		t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites)
+	}
+}
+
 // Don't ship code with debug=true.
 func TestDebug(t *testing.T) {
 	if debugMux {