ssh: eliminate some goroutine leaks in tests and examples

This should fix the "Log in goroutine" panic seen in
https://build.golang.org/log/e42bf69fc002113dbccfe602a6c67fd52e8f31df,
as well as a few other related leaks. It also helps to verify that
none of the functions under test deadlock unexpectedly.

See https://go.dev/wiki/CodeReviewComments#goroutine-lifetimes.

Updates golang/go#58901.

Change-Id: Ica943444db381ae1accb80b101ea646e28ebf4f9
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/541095
Auto-Submit: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/ssh/example_test.go b/ssh/example_test.go
index 0a6b076..3920832 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -16,6 +16,7 @@
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 
 	"golang.org/x/crypto/ssh"
 	"golang.org/x/crypto/ssh/terminal"
@@ -98,8 +99,15 @@
 	}
 	log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
 
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
 	// The incoming Request channel must be serviced.
-	go ssh.DiscardRequests(reqs)
+	wg.Add(1)
+	go func() {
+		ssh.DiscardRequests(reqs)
+		wg.Done()
+	}()
 
 	// Service the incoming Channel channel.
 	for newChannel := range chans {
@@ -119,16 +127,22 @@
 		// Sessions have out-of-band requests such as "shell",
 		// "pty-req" and "env".  Here we handle only the
 		// "shell" request.
+		wg.Add(1)
 		go func(in <-chan *ssh.Request) {
 			for req := range in {
 				req.Reply(req.Type == "shell", nil)
 			}
+			wg.Done()
 		}(requests)
 
 		term := terminal.NewTerminal(channel, "> ")
 
+		wg.Add(1)
 		go func() {
-			defer channel.Close()
+			defer func() {
+				channel.Close()
+				wg.Done()
+			}()
 			for {
 				line, err := term.ReadLine()
 				if err != nil {
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index 1db3be5..eae637d 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -10,7 +10,6 @@
 	"io"
 	"sync"
 	"testing"
-	"time"
 )
 
 func muxPair() (*mux, *mux) {
@@ -112,7 +111,11 @@
 
 	magic := "hello world"
 	magicExt := "hello stderr"
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		_, err := s.Write([]byte(magic))
 		if err != nil {
 			t.Errorf("Write: %v", err)
@@ -152,13 +155,15 @@
 	defer writer.Close()
 	defer mux.Close()
 
-	wDone := make(chan int, 1)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
 			t.Errorf("could not fill window: %v", err)
 		}
 		writer.Write(make([]byte, 1))
-		wDone <- 1
 	}()
 	writer.remoteWin.waitWriterBlocked()
 
@@ -175,7 +180,6 @@
 	if _, err := reader.SendRequest("hello", true, nil); err == nil {
 		t.Errorf("SendRequest succeeded.")
 	}
-	<-wDone
 }
 
 func TestMuxChannelCloseWriteUnblock(t *testing.T) {
@@ -184,20 +188,21 @@
 	defer writer.Close()
 	defer mux.Close()
 
-	wDone := make(chan int, 1)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
 			t.Errorf("could not fill window: %v", err)
 		}
 		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
 			t.Errorf("got %v, want EOF for unblock write", err)
 		}
-		wDone <- 1
 	}()
 
 	writer.remoteWin.waitWriterBlocked()
 	reader.Close()
-	<-wDone
 }
 
 func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
@@ -206,20 +211,21 @@
 	defer writer.Close()
 	defer mux.Close()
 
-	wDone := make(chan int, 1)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
 			t.Errorf("could not fill window: %v", err)
 		}
 		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
 			t.Errorf("got %v, want EOF for unblock write", err)
 		}
-		wDone <- 1
 	}()
 
 	writer.remoteWin.waitWriterBlocked()
 	mux.Close()
-	<-wDone
 }
 
 func TestMuxReject(t *testing.T) {
@@ -227,7 +233,12 @@
 	defer server.Close()
 	defer client.Close()
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
+
 		ch, ok := <-server.incomingChannels
 		if !ok {
 			t.Error("cannot accept channel")
@@ -267,6 +278,7 @@
 
 	var received int
 	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
 	wg.Add(1)
 	go func() {
 		for r := range server.incomingRequests {
@@ -295,7 +307,6 @@
 	}
 	if ok {
 		t.Errorf("SendRequest(no): %v", ok)
-
 	}
 
 	client.Close()
@@ -389,13 +400,8 @@
 
 	// Wait for the server to send the keepalive message and receive back a
 	// response.
-	select {
-	case err := <-kDone:
-		if err != nil {
-			t.Fatal(err)
-		}
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never received ack")
+	if err := <-kDone; err != nil {
+		t.Fatal(err)
 	}
 
 	// Confirm client hasn't closed.
@@ -403,13 +409,9 @@
 		t.Fatalf("failed to send keepalive: %v", err)
 	}
 
-	select {
-	case err := <-kDone:
-		if err != nil {
-			t.Fatal(err)
-		}
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never shut down")
+	// Wait for the server to shut down.
+	if err := <-kDone; err != nil {
+		t.Fatal(err)
 	}
 }
 
@@ -525,11 +527,7 @@
 	defer ch.Close()
 
 	// Wait for the server to close the channel and send the keepalive.
-	select {
-	case <-kDone:
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never received ack")
-	}
+	<-kDone
 
 	// Make sure the channel closed.
 	if _, ok := <-ch.incomingRequests; ok {
@@ -541,22 +539,29 @@
 		t.Fatalf("failed to send keepalive: %v", err)
 	}
 
-	select {
-	case <-kDone:
-	case <-time.After(10 * time.Second):
-		t.Fatalf("server never shut down")
-	}
+	// Wait for the server to shut down.
+	<-kDone
 }
 
 func TestMuxGlobalRequest(t *testing.T) {
+	var sawPeek bool
+	var wg sync.WaitGroup
+	defer func() {
+		wg.Wait()
+		if !sawPeek {
+			t.Errorf("never saw 'peek' request")
+		}
+	}()
+
 	clientMux, serverMux := muxPair()
 	defer serverMux.Close()
 	defer clientMux.Close()
 
-	var seen bool
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		for r := range serverMux.incomingRequests {
-			seen = seen || r.Type == "peek"
+			sawPeek = sawPeek || r.Type == "peek"
 			if r.WantReply {
 				err := r.Reply(r.Type == "yes",
 					append([]byte(r.Type), r.Payload...))
@@ -586,10 +591,6 @@
 		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
 			ok, data, err)
 	}
-
-	if !seen {
-		t.Errorf("never saw 'peek' request")
-	}
 }
 
 func TestMuxGlobalRequestUnblock(t *testing.T) {
@@ -739,7 +740,13 @@
 		t.Errorf("could not send packet")
 	}
 
-	go a.SendRequest("hello", false, nil)
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
+	go func() {
+		a.SendRequest("hello", false, nil)
+		wg.Done()
+	}()
 
 	_, ok := <-b.incomingRequests
 	if ok {
diff --git a/ssh/session_test.go b/ssh/session_test.go
index 521677f..807a913 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -13,6 +13,7 @@
 	"io"
 	"math/rand"
 	"net"
+	"sync"
 	"testing"
 
 	"golang.org/x/crypto/ssh/terminal"
@@ -27,8 +28,14 @@
 		t.Fatalf("netPipe: %v", err)
 	}
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
-		defer c1.Close()
+		defer func() {
+			c1.Close()
+			wg.Done()
+		}()
 		conf := ServerConfig{
 			NoClientAuth: true,
 		}
@@ -39,7 +46,11 @@
 			t.Errorf("Unable to handshake: %v", err)
 			return
 		}
-		go DiscardRequests(reqs)
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
 
 		for newCh := range chans {
 			if newCh.ChannelType() != "session" {
@@ -52,8 +63,10 @@
 				t.Errorf("Accept: %v", err)
 				continue
 			}
+			wg.Add(1)
 			go func() {
 				handler(ch, inReqs, t)
+				wg.Done()
 			}()
 		}
 		if err := conn.Wait(); err != io.EOF {
@@ -338,8 +351,13 @@
 		t.Fatal(err)
 	}
 	defer session.Close()
-	result := make(chan []byte)
 
+	serverStdin, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("StdinPipe failed: %v", err)
+	}
+
+	result := make(chan []byte)
 	go func() {
 		defer close(result)
 		echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
@@ -355,10 +373,6 @@
 		result <- echoedBuf.Bytes()
 	}()
 
-	serverStdin, err := session.StdinPipe()
-	if err != nil {
-		t.Fatalf("StdinPipe failed: %v", err)
-	}
 	written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
 	if err != nil {
 		t.Errorf("failed to copy origBuf to serverStdin: %v", err)
@@ -648,29 +662,44 @@
 		User:            "user",
 	}
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+
 	srvErrCh := make(chan error, 1)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		conn, chans, reqs, err := NewServerConn(c1, serverConf)
 		srvErrCh <- err
 		if err != nil {
 			return
 		}
 		serverID <- conn.SessionID()
-		go DiscardRequests(reqs)
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
 		for ch := range chans {
 			ch.Reject(Prohibited, "")
 		}
 	}()
 
 	cliErrCh := make(chan error, 1)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
 		cliErrCh <- err
 		if err != nil {
 			return
 		}
 		clientID <- conn.SessionID()
-		go DiscardRequests(reqs)
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
 		for ch := range chans {
 			ch.Reject(Prohibited, "")
 		}
@@ -738,6 +767,8 @@
 	serverConf.AddHostKey(testSigners["rsa"])
 	serverConf.AddHostKey(testSigners["ecdsa"])
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
 	connect := func(clientConf *ClientConfig, want string) {
 		var alg string
 		clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
@@ -751,7 +782,11 @@
 		defer c1.Close()
 		defer c2.Close()
 
-		go NewServerConn(c1, serverConf)
+		wg.Add(1)
+		go func() {
+			NewServerConn(c1, serverConf)
+			wg.Done()
+		}()
 		_, _, _, err = NewClientConn(c2, "", clientConf)
 		if err != nil {
 			t.Fatalf("NewClientConn: %v", err)
@@ -785,7 +820,11 @@
 	defer c1.Close()
 	defer c2.Close()
 
-	go NewServerConn(c1, serverConf)
+	wg.Add(1)
+	go func() {
+		NewServerConn(c1, serverConf)
+		wg.Done()
+	}()
 	clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
 	_, _, _, err = NewClientConn(c2, "", clientConf)
 	if err == nil {
@@ -818,14 +857,22 @@
 		User:            someUsername,
 	}
 
+	var wg sync.WaitGroup
+	t.Cleanup(wg.Wait)
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		_, chans, reqs, err := NewServerConn(c1, serverConf)
 		if err != nil {
 			t.Errorf("server handshake: %v", err)
 			userCh <- "error"
 			return
 		}
-		go DiscardRequests(reqs)
+		wg.Add(1)
+		go func() {
+			DiscardRequests(reqs)
+			wg.Done()
+		}()
 		for ch := range chans {
 			ch.Reject(Prohibited, "")
 		}