ssh: eliminate some goroutine leaks in tests and examples
This should fix the "Log in goroutine" panic seen in
https://build.golang.org/log/e42bf69fc002113dbccfe602a6c67fd52e8f31df,
as well as a few other related leaks. It also helps to verify that
none of the functions under test deadlock unexpectedly.
See https://go.dev/wiki/CodeReviewComments#goroutine-lifetimes.
Updates golang/go#58901.
Change-Id: Ica943444db381ae1accb80b101ea646e28ebf4f9
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/541095
Auto-Submit: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Heschi Kreinick <heschi@google.com>
diff --git a/ssh/example_test.go b/ssh/example_test.go
index 0a6b076..3920832 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -16,6 +16,7 @@
"os"
"path/filepath"
"strings"
+ "sync"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
@@ -98,8 +99,15 @@
}
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
// The incoming Request channel must be serviced.
- go ssh.DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ ssh.DiscardRequests(reqs)
+ wg.Done()
+ }()
// Service the incoming Channel channel.
for newChannel := range chans {
@@ -119,16 +127,22 @@
// Sessions have out-of-band requests such as "shell",
// "pty-req" and "env". Here we handle only the
// "shell" request.
+ wg.Add(1)
go func(in <-chan *ssh.Request) {
for req := range in {
req.Reply(req.Type == "shell", nil)
}
+ wg.Done()
}(requests)
term := terminal.NewTerminal(channel, "> ")
+ wg.Add(1)
go func() {
- defer channel.Close()
+ defer func() {
+ channel.Close()
+ wg.Done()
+ }()
for {
line, err := term.ReadLine()
if err != nil {
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index 1db3be5..eae637d 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -10,7 +10,6 @@
"io"
"sync"
"testing"
- "time"
)
func muxPair() (*mux, *mux) {
@@ -112,7 +111,11 @@
magic := "hello world"
magicExt := "hello stderr"
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
+ defer wg.Done()
_, err := s.Write([]byte(magic))
if err != nil {
t.Errorf("Write: %v", err)
@@ -152,13 +155,15 @@
defer writer.Close()
defer mux.Close()
- wDone := make(chan int, 1)
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
+ defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
writer.Write(make([]byte, 1))
- wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
@@ -175,7 +180,6 @@
if _, err := reader.SendRequest("hello", true, nil); err == nil {
t.Errorf("SendRequest succeeded.")
}
- <-wDone
}
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
@@ -184,20 +188,21 @@
defer writer.Close()
defer mux.Close()
- wDone := make(chan int, 1)
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
+ defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
- wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
reader.Close()
- <-wDone
}
func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
@@ -206,20 +211,21 @@
defer writer.Close()
defer mux.Close()
- wDone := make(chan int, 1)
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
+ defer wg.Done()
if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
t.Errorf("could not fill window: %v", err)
}
if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
t.Errorf("got %v, want EOF for unblock write", err)
}
- wDone <- 1
}()
writer.remoteWin.waitWriterBlocked()
mux.Close()
- <-wDone
}
func TestMuxReject(t *testing.T) {
@@ -227,7 +233,12 @@
defer server.Close()
defer client.Close()
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
+ defer wg.Done()
+
ch, ok := <-server.incomingChannels
if !ok {
t.Error("cannot accept channel")
@@ -267,6 +278,7 @@
var received int
var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
wg.Add(1)
go func() {
for r := range server.incomingRequests {
@@ -295,7 +307,6 @@
}
if ok {
t.Errorf("SendRequest(no): %v", ok)
-
}
client.Close()
@@ -389,13 +400,8 @@
// Wait for the server to send the keepalive message and receive back a
// response.
- select {
- case err := <-kDone:
- if err != nil {
- t.Fatal(err)
- }
- case <-time.After(10 * time.Second):
- t.Fatalf("server never received ack")
+ if err := <-kDone; err != nil {
+ t.Fatal(err)
}
// Confirm client hasn't closed.
@@ -403,13 +409,9 @@
t.Fatalf("failed to send keepalive: %v", err)
}
- select {
- case err := <-kDone:
- if err != nil {
- t.Fatal(err)
- }
- case <-time.After(10 * time.Second):
- t.Fatalf("server never shut down")
+ // Wait for the server to shut down.
+ if err := <-kDone; err != nil {
+ t.Fatal(err)
}
}
@@ -525,11 +527,7 @@
defer ch.Close()
// Wait for the server to close the channel and send the keepalive.
- select {
- case <-kDone:
- case <-time.After(10 * time.Second):
- t.Fatalf("server never received ack")
- }
+ <-kDone
// Make sure the channel closed.
if _, ok := <-ch.incomingRequests; ok {
@@ -541,22 +539,29 @@
t.Fatalf("failed to send keepalive: %v", err)
}
- select {
- case <-kDone:
- case <-time.After(10 * time.Second):
- t.Fatalf("server never shut down")
- }
+ // Wait for the server to shut down.
+ <-kDone
}
func TestMuxGlobalRequest(t *testing.T) {
+ var sawPeek bool
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ if !sawPeek {
+ t.Errorf("never saw 'peek' request")
+ }
+ }()
+
clientMux, serverMux := muxPair()
defer serverMux.Close()
defer clientMux.Close()
- var seen bool
+ wg.Add(1)
go func() {
+ defer wg.Done()
for r := range serverMux.incomingRequests {
- seen = seen || r.Type == "peek"
+ sawPeek = sawPeek || r.Type == "peek"
if r.WantReply {
err := r.Reply(r.Type == "yes",
append([]byte(r.Type), r.Payload...))
@@ -586,10 +591,6 @@
t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
ok, data, err)
}
-
- if !seen {
- t.Errorf("never saw 'peek' request")
- }
}
func TestMuxGlobalRequestUnblock(t *testing.T) {
@@ -739,7 +740,13 @@
t.Errorf("could not send packet")
}
- go a.SendRequest("hello", false, nil)
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
+ go func() {
+ a.SendRequest("hello", false, nil)
+ wg.Done()
+ }()
_, ok := <-b.incomingRequests
if ok {
diff --git a/ssh/session_test.go b/ssh/session_test.go
index 521677f..807a913 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -13,6 +13,7 @@
"io"
"math/rand"
"net"
+ "sync"
"testing"
"golang.org/x/crypto/ssh/terminal"
@@ -27,8 +28,14 @@
t.Fatalf("netPipe: %v", err)
}
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
- defer c1.Close()
+ defer func() {
+ c1.Close()
+ wg.Done()
+ }()
conf := ServerConfig{
NoClientAuth: true,
}
@@ -39,7 +46,11 @@
t.Errorf("Unable to handshake: %v", err)
return
}
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for newCh := range chans {
if newCh.ChannelType() != "session" {
@@ -52,8 +63,10 @@
t.Errorf("Accept: %v", err)
continue
}
+ wg.Add(1)
go func() {
handler(ch, inReqs, t)
+ wg.Done()
}()
}
if err := conn.Wait(); err != io.EOF {
@@ -338,8 +351,13 @@
t.Fatal(err)
}
defer session.Close()
- result := make(chan []byte)
+ serverStdin, err := session.StdinPipe()
+ if err != nil {
+ t.Fatalf("StdinPipe failed: %v", err)
+ }
+
+ result := make(chan []byte)
go func() {
defer close(result)
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
@@ -355,10 +373,6 @@
result <- echoedBuf.Bytes()
}()
- serverStdin, err := session.StdinPipe()
- if err != nil {
- t.Fatalf("StdinPipe failed: %v", err)
- }
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
if err != nil {
t.Errorf("failed to copy origBuf to serverStdin: %v", err)
@@ -648,29 +662,44 @@
User: "user",
}
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+
srvErrCh := make(chan error, 1)
+ wg.Add(1)
go func() {
+ defer wg.Done()
conn, chans, reqs, err := NewServerConn(c1, serverConf)
srvErrCh <- err
if err != nil {
return
}
serverID <- conn.SessionID()
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()
cliErrCh := make(chan error, 1)
+ wg.Add(1)
go func() {
+ defer wg.Done()
conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
cliErrCh <- err
if err != nil {
return
}
clientID <- conn.SessionID()
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for ch := range chans {
ch.Reject(Prohibited, "")
}
@@ -738,6 +767,8 @@
serverConf.AddHostKey(testSigners["rsa"])
serverConf.AddHostKey(testSigners["ecdsa"])
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
connect := func(clientConf *ClientConfig, want string) {
var alg string
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
@@ -751,7 +782,11 @@
defer c1.Close()
defer c2.Close()
- go NewServerConn(c1, serverConf)
+ wg.Add(1)
+ go func() {
+ NewServerConn(c1, serverConf)
+ wg.Done()
+ }()
_, _, _, err = NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
@@ -785,7 +820,11 @@
defer c1.Close()
defer c2.Close()
- go NewServerConn(c1, serverConf)
+ wg.Add(1)
+ go func() {
+ NewServerConn(c1, serverConf)
+ wg.Done()
+ }()
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
_, _, _, err = NewClientConn(c2, "", clientConf)
if err == nil {
@@ -818,14 +857,22 @@
User: someUsername,
}
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
+ defer wg.Done()
_, chans, reqs, err := NewServerConn(c1, serverConf)
if err != nil {
t.Errorf("server handshake: %v", err)
userCh <- "error"
return
}
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for ch := range chans {
ch.Reject(Prohibited, "")
}