ssh: fix call to Fatalf from a non-test goroutine

Also fix some redundant type declarations.

Change-Id: Iad2950b67b1ec2e2590c59393b8ad15421ed3add
GitHub-Last-Rev: 41cf552f11387208491dee7b867050475043b25e
GitHub-Pull-Request: golang/crypto#263
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/505798
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: David Chase <drchase@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go
index c27eaa9..fdc8000 100644
--- a/ssh/agent/client_test.go
+++ b/ssh/agent/client_test.go
@@ -369,7 +369,8 @@
 	go func() {
 		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
 		if err != nil {
-			t.Fatalf("Server: %v", err)
+			t.Errorf("NewServerConn error: %v", err)
+			return
 		}
 		conn.Close()
 	}()
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go
index 038018e..0af8545 100644
--- a/ssh/agent/server_test.go
+++ b/ssh/agent/server_test.go
@@ -53,10 +53,11 @@
 	incoming := make(chan *ssh.ServerConn, 1)
 	go func() {
 		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
-		if err != nil {
-			t.Fatalf("Server: %v", err)
-		}
 		incoming <- conn
+		if err != nil {
+			t.Errorf("NewServerConn error: %v", err)
+			return
+		}
 	}()
 
 	conf := ssh.ClientConfig{
@@ -71,8 +72,10 @@
 	if err := ForwardToRemote(client, socket); err != nil {
 		t.Fatalf("SetupForwardAgent: %v", err)
 	}
-
 	server := <-incoming
+	if server == nil {
+		t.Fatal("Unable to get server")
+	}
 	ch, reqs, err := server.OpenChannel(channelType, nil)
 	if err != nil {
 		t.Fatalf("OpenChannel(%q): %v", channelType, err)
diff --git a/ssh/benchmark_test.go b/ssh/benchmark_test.go
index a13235d..b356330 100644
--- a/ssh/benchmark_test.go
+++ b/ssh/benchmark_test.go
@@ -6,6 +6,7 @@
 
 import (
 	"errors"
+	"fmt"
 	"io"
 	"net"
 	"testing"
@@ -90,16 +91,16 @@
 	go func() {
 		newCh, err := server.Accept()
 		if err != nil {
-			b.Fatalf("Client: %v", err)
+			panic(fmt.Sprintf("Client: %v", err))
 		}
 		ch, incoming, err := newCh.Accept()
 		if err != nil {
-			b.Fatalf("Accept: %v", err)
+			panic(fmt.Sprintf("Accept: %v", err))
 		}
 		go DiscardRequests(incoming)
 		for i := 0; i < b.N; i++ {
 			if _, err := io.ReadFull(ch, output); err != nil {
-				b.Fatalf("ReadFull: %v", err)
+				panic(fmt.Sprintf("ReadFull: %v", err))
 			}
 		}
 		ch.Close()
diff --git a/ssh/common_test.go b/ssh/common_test.go
index 96744dc..a7beee8 100644
--- a/ssh/common_test.go
+++ b/ssh/common_test.go
@@ -82,11 +82,11 @@
 	}
 
 	cases := []testcase{
-		testcase{
+		{
 			name: "standard",
 		},
 
-		testcase{
+		{
 			name: "no common hostkey",
 			serverIn: kexInitMsg{
 				ServerHostKeyAlgos: []string{"hostkey2"},
@@ -94,7 +94,7 @@
 			wantErr: true,
 		},
 
-		testcase{
+		{
 			name: "no common kex",
 			serverIn: kexInitMsg{
 				KexAlgos: []string{"kex2"},
@@ -102,7 +102,7 @@
 			wantErr: true,
 		},
 
-		testcase{
+		{
 			name: "no common cipher",
 			serverIn: kexInitMsg{
 				CiphersClientServer: []string{"cipher2"},
@@ -110,7 +110,7 @@
 			wantErr: true,
 		},
 
-		testcase{
+		{
 			name: "client decides cipher",
 			serverIn: kexInitMsg{
 				CiphersClientServer: []string{"cipher1", "cipher2"},
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index f190cbf..879143a 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -148,6 +148,7 @@
 	clientDone := make(chan int, 0)
 	gotHalf := make(chan int, 0)
 	const N = 20
+	errorCh := make(chan error, 1)
 
 	go func() {
 		defer close(clientDone)
@@ -158,7 +159,9 @@
 		for i := 0; i < N; i++ {
 			p := []byte{msgRequestSuccess, byte(i)}
 			if err := trC.writePacket(p); err != nil {
-				t.Fatalf("sendPacket: %v", err)
+				errorCh <- err
+				trC.Close()
+				return
 			}
 			if (i % 10) == 5 {
 				<-gotHalf
@@ -177,16 +180,15 @@
 				checker.waitCall <- 1
 			}
 		}
+		errorCh <- nil
 	}()
 
 	// Server checks that client messages come in cleanly
 	i := 0
-	err = nil
 	for ; i < N; i++ {
-		var p []byte
-		p, err = trS.readPacket()
-		if err != nil {
-			break
+		p, err := trS.readPacket()
+		if err != nil && err != io.EOF {
+			t.Fatalf("server error: %v", err)
 		}
 		if (i % 10) == 5 {
 			gotHalf <- 1
@@ -198,8 +200,8 @@
 		}
 	}
 	<-clientDone
-	if err != nil && err != io.EOF {
-		t.Fatalf("server error: %v", err)
+	if err := <-errorCh; err != nil {
+		t.Fatalf("sendPacket: %v", err)
 	}
 	if i != N {
 		t.Errorf("received %d messages, want 10.", i)
@@ -345,16 +347,16 @@
 
 	// While we read out the packet, a key change will be
 	// initiated.
-	done := make(chan int, 1)
+	errorCh := make(chan error, 1)
 	go func() {
-		defer close(done)
-		if _, err := trC.readPacket(); err != nil {
-			t.Fatalf("readPacket(client): %v", err)
-		}
-
+		_, err := trC.readPacket()
+		errorCh <- err
 	}()
 
-	<-done
+	if err := <-errorCh; err != nil {
+		t.Fatalf("readPacket(client): %v", err)
+	}
+
 	<-sync.called
 }
 
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index 393017c..1db3be5 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -5,6 +5,8 @@
 package ssh
 
 import (
+	"errors"
+	"fmt"
 	"io"
 	"sync"
 	"testing"
@@ -29,14 +31,21 @@
 	go func() {
 		newCh, ok := <-s.incomingChannels
 		if !ok {
-			t.Fatalf("No incoming channel")
+			t.Error("no incoming channel")
+			close(res)
+			return
 		}
 		if newCh.ChannelType() != "chan" {
-			t.Fatalf("got type %q want chan", newCh.ChannelType())
+			t.Errorf("got type %q want chan", newCh.ChannelType())
+			newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType()))
+			close(res)
+			return
 		}
 		ch, _, err := newCh.Accept()
 		if err != nil {
-			t.Fatalf("Accept %v", err)
+			t.Errorf("accept: %v", err)
+			close(res)
+			return
 		}
 		res <- ch.(*channel)
 	}()
@@ -45,8 +54,12 @@
 	if err != nil {
 		t.Fatalf("OpenChannel: %v", err)
 	}
+	w := <-res
+	if w == nil {
+		t.Fatal("unable to get write channel")
+	}
 
-	return <-res, ch, c
+	return w, ch, c
 }
 
 // Test that stderr and stdout can be addressed from different
@@ -74,14 +87,14 @@
 	go func() {
 		c, err := io.ReadAll(reader)
 		if string(c) != magic {
-			t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
+			t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err)
 		}
 		rd.Done()
 	}()
 	go func() {
 		c, err := io.ReadAll(reader.Stderr())
 		if string(c) != magic {
-			t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
+			t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err)
 		}
 		rd.Done()
 	}()
@@ -102,11 +115,13 @@
 	go func() {
 		_, err := s.Write([]byte(magic))
 		if err != nil {
-			t.Fatalf("Write: %v", err)
+			t.Errorf("Write: %v", err)
+			return
 		}
 		_, err = s.Extended(1).Write([]byte(magicExt))
 		if err != nil {
-			t.Fatalf("Write: %v", err)
+			t.Errorf("Write: %v", err)
+			return
 		}
 	}()
 
@@ -215,10 +230,13 @@
 	go func() {
 		ch, ok := <-server.incomingChannels
 		if !ok {
-			t.Fatalf("Accept")
+			t.Error("cannot accept channel")
+			return
 		}
 		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
-			t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+			t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+			ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String())
+			return
 		}
 		ch.Reject(RejectionReason(42), "message")
 	}()
@@ -294,7 +312,7 @@
 	defer serverPipe.Close()
 	defer client.Close()
 
-	kDone := make(chan struct{})
+	kDone := make(chan error, 1)
 	go func() {
 		// Ignore unknown channel messages that don't want a reply.
 		err := serverPipe.writePacket(Marshal(channelRequestMsg{
@@ -304,7 +322,8 @@
 			RequestSpecificData: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Send a keepalive, which should get a channel failure message
@@ -316,44 +335,53 @@
 			RequestSpecificData: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		packet, err := serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		decoded, err := decode(packet)
 		if err != nil {
-			t.Fatalf("decode failed: %v", err)
+			kDone <- fmt.Errorf("decode failed: %w", err)
+			return
 		}
 
 		switch msg := decoded.(type) {
 		case *channelRequestFailureMsg:
 			if msg.PeersID != 2 {
-				t.Fatalf("received response to wrong message: %v", msg)
+				kDone <- fmt.Errorf("received response to wrong message: %v", msg)
+				return
+
 			}
 		default:
-			t.Fatalf("unexpected channel message: %v", msg)
+			kDone <- fmt.Errorf("unexpected channel message: %v", msg)
+			return
 		}
 
-		kDone <- struct{}{}
+		kDone <- nil
 
 		// Receive and respond to the keepalive to confirm the mux is
 		// still processing requests.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgGlobalRequest {
-			t.Fatalf("expected global request")
+			kDone <- errors.New("expected global request")
+			return
 		}
 
 		err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
 			Data: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("failed to send failure msg: %v", err)
+			kDone <- fmt.Errorf("failed to send failure msg: %w", err)
+			return
 		}
 
 		close(kDone)
@@ -362,7 +390,10 @@
 	// Wait for the server to send the keepalive message and receive back a
 	// response.
 	select {
-	case <-kDone:
+	case err := <-kDone:
+		if err != nil {
+			t.Fatal(err)
+		}
 	case <-time.After(10 * time.Second):
 		t.Fatalf("server never received ack")
 	}
@@ -373,7 +404,10 @@
 	}
 
 	select {
-	case <-kDone:
+	case err := <-kDone:
+		if err != nil {
+			t.Fatal(err)
+		}
 	case <-time.After(10 * time.Second):
 		t.Fatalf("server never shut down")
 	}
@@ -385,20 +419,23 @@
 	defer serverPipe.Close()
 	defer client.Close()
 
-	kDone := make(chan struct{})
+	kDone := make(chan error, 1)
 	go func() {
 		// Open the channel.
 		packet, err := serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgChannelOpen {
-			t.Fatalf("expected chan open")
+			kDone <- errors.New("expected chan open")
+			return
 		}
 
 		var openMsg channelOpenMsg
 		if err := Unmarshal(packet, &openMsg); err != nil {
-			t.Fatalf("unmarshal: %v", err)
+			kDone <- fmt.Errorf("unmarshal: %w", err)
+			return
 		}
 
 		// Send back the opened channel confirmation.
@@ -409,7 +446,8 @@
 			MaxPacketSize: channelMaxPacket,
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Close the channel.
@@ -417,7 +455,8 @@
 			PeersID: openMsg.PeersID,
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Send a keepalive message on the channel we just closed.
@@ -428,43 +467,51 @@
 			RequestSpecificData: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("send: %v", err)
+			kDone <- fmt.Errorf("send: %w", err)
+			return
 		}
 
 		// Receive the channel closed response.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgChannelClose {
-			t.Fatalf("expected channel close")
+			kDone <- errors.New("expected channel close")
+			return
 		}
 
 		// Receive the keepalive response failure.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgChannelFailure {
-			t.Fatalf("expected channel close")
+			kDone <- errors.New("expected channel failure")
+			return
 		}
-		kDone <- struct{}{}
+		kDone <- nil
 
 		// Receive and respond to the keepalive to confirm the mux is
 		// still processing requests.
 		packet, err = serverPipe.readPacket()
 		if err != nil {
-			t.Fatalf("read packet: %v", err)
+			kDone <- fmt.Errorf("read packet: %w", err)
+			return
 		}
 		if packet[0] != msgGlobalRequest {
-			t.Fatalf("expected global request")
+			kDone <- errors.New("expected global request")
+			return
 		}
 
 		err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
 			Data: []byte{},
 		}))
 		if err != nil {
-			t.Fatalf("failed to send failure msg: %v", err)
+			kDone <- fmt.Errorf("failed to send failure msg: %w", err)
+			return
 		}
 
 		close(kDone)
diff --git a/ssh/session_test.go b/ssh/session_test.go
index c4b9f0e..521677f 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -36,7 +36,8 @@
 
 		conn, chans, reqs, err := NewServerConn(c1, &conf)
 		if err != nil {
-			t.Fatalf("Unable to handshake: %v", err)
+			t.Errorf("Unable to handshake: %v", err)
+			return
 		}
 		go DiscardRequests(reqs)
 
@@ -647,10 +648,12 @@
 		User:            "user",
 	}
 
+	srvErrCh := make(chan error, 1)
 	go func() {
 		conn, chans, reqs, err := NewServerConn(c1, serverConf)
+		srvErrCh <- err
 		if err != nil {
-			t.Fatalf("server handshake: %v", err)
+			return
 		}
 		serverID <- conn.SessionID()
 		go DiscardRequests(reqs)
@@ -659,10 +662,12 @@
 		}
 	}()
 
+	cliErrCh := make(chan error, 1)
 	go func() {
 		conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
+		cliErrCh <- err
 		if err != nil {
-			t.Fatalf("client handshake: %v", err)
+			return
 		}
 		clientID <- conn.SessionID()
 		go DiscardRequests(reqs)
@@ -671,6 +676,14 @@
 		}
 	}()
 
+	if err := <-srvErrCh; err != nil {
+		t.Fatalf("server handshake: %v", err)
+	}
+
+	if err := <-cliErrCh; err != nil {
+		t.Fatalf("client handshake: %v", err)
+	}
+
 	s := <-serverID
 	c := <-clientID
 	if bytes.Compare(s, c) != 0 {
diff --git a/ssh/test/multi_auth_test.go b/ssh/test/multi_auth_test.go
index 6c253a7..403d736 100644
--- a/ssh/test/multi_auth_test.go
+++ b/ssh/test/multi_auth_test.go
@@ -77,27 +77,27 @@
 func TestMultiAuth(t *testing.T) {
 	testCases := []multiAuthTestCase{
 		// Test password,publickey authentication, assert that password callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:         []string{"password", "publickey"},
 			expectedPasswordCbs: 1,
 		},
 		// Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:       []string{"keyboard-interactive", "publickey"},
 			expectedKbdIntCbs: 1,
 		},
 		// Test publickey,password authentication, assert that password callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:         []string{"publickey", "password"},
 			expectedPasswordCbs: 1,
 		},
 		// Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time
-		multiAuthTestCase{
+		{
 			authMethods:       []string{"publickey", "keyboard-interactive"},
 			expectedKbdIntCbs: 1,
 		},
 		// Test password,password authentication, assert that password callback is called 2 times
-		multiAuthTestCase{
+		{
 			authMethods:         []string{"password", "password"},
 			expectedPasswordCbs: 2,
 		},