ssh: fix call to Fatalf from a non-test goroutine
Also fix some redundant type declarations.
Change-Id: Iad2950b67b1ec2e2590c59393b8ad15421ed3add
GitHub-Last-Rev: 41cf552f11387208491dee7b867050475043b25e
GitHub-Pull-Request: golang/crypto#263
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/505798
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: David Chase <drchase@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Filippo Valsorda <filippo@golang.org>
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go
index c27eaa9..fdc8000 100644
--- a/ssh/agent/client_test.go
+++ b/ssh/agent/client_test.go
@@ -369,7 +369,8 @@
go func() {
conn, _, _, err := ssh.NewServerConn(a, &serverConf)
if err != nil {
- t.Fatalf("Server: %v", err)
+ t.Errorf("NewServerConn error: %v", err)
+ return
}
conn.Close()
}()
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go
index 038018e..0af8545 100644
--- a/ssh/agent/server_test.go
+++ b/ssh/agent/server_test.go
@@ -53,10 +53,11 @@
incoming := make(chan *ssh.ServerConn, 1)
go func() {
conn, _, _, err := ssh.NewServerConn(a, &serverConf)
- if err != nil {
- t.Fatalf("Server: %v", err)
- }
incoming <- conn
+ if err != nil {
+ t.Errorf("NewServerConn error: %v", err)
+ return
+ }
}()
conf := ssh.ClientConfig{
@@ -71,8 +72,10 @@
if err := ForwardToRemote(client, socket); err != nil {
t.Fatalf("SetupForwardAgent: %v", err)
}
-
server := <-incoming
+ if server == nil {
+ t.Fatal("Unable to get server")
+ }
ch, reqs, err := server.OpenChannel(channelType, nil)
if err != nil {
t.Fatalf("OpenChannel(%q): %v", channelType, err)
diff --git a/ssh/benchmark_test.go b/ssh/benchmark_test.go
index a13235d..b356330 100644
--- a/ssh/benchmark_test.go
+++ b/ssh/benchmark_test.go
@@ -6,6 +6,7 @@
import (
"errors"
+ "fmt"
"io"
"net"
"testing"
@@ -90,16 +91,16 @@
go func() {
newCh, err := server.Accept()
if err != nil {
- b.Fatalf("Client: %v", err)
+ panic(fmt.Sprintf("Client: %v", err))
}
ch, incoming, err := newCh.Accept()
if err != nil {
- b.Fatalf("Accept: %v", err)
+ panic(fmt.Sprintf("Accept: %v", err))
}
go DiscardRequests(incoming)
for i := 0; i < b.N; i++ {
if _, err := io.ReadFull(ch, output); err != nil {
- b.Fatalf("ReadFull: %v", err)
+ panic(fmt.Sprintf("ReadFull: %v", err))
}
}
ch.Close()
diff --git a/ssh/common_test.go b/ssh/common_test.go
index 96744dc..a7beee8 100644
--- a/ssh/common_test.go
+++ b/ssh/common_test.go
@@ -82,11 +82,11 @@
}
cases := []testcase{
- testcase{
+ {
name: "standard",
},
- testcase{
+ {
name: "no common hostkey",
serverIn: kexInitMsg{
ServerHostKeyAlgos: []string{"hostkey2"},
@@ -94,7 +94,7 @@
wantErr: true,
},
- testcase{
+ {
name: "no common kex",
serverIn: kexInitMsg{
KexAlgos: []string{"kex2"},
@@ -102,7 +102,7 @@
wantErr: true,
},
- testcase{
+ {
name: "no common cipher",
serverIn: kexInitMsg{
CiphersClientServer: []string{"cipher2"},
@@ -110,7 +110,7 @@
wantErr: true,
},
- testcase{
+ {
name: "client decides cipher",
serverIn: kexInitMsg{
CiphersClientServer: []string{"cipher1", "cipher2"},
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index f190cbf..879143a 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -148,6 +148,7 @@
clientDone := make(chan int, 0)
gotHalf := make(chan int, 0)
const N = 20
+ errorCh := make(chan error, 1)
go func() {
defer close(clientDone)
@@ -158,7 +159,9 @@
for i := 0; i < N; i++ {
p := []byte{msgRequestSuccess, byte(i)}
if err := trC.writePacket(p); err != nil {
- t.Fatalf("sendPacket: %v", err)
+ errorCh <- err
+ trC.Close()
+ return
}
if (i % 10) == 5 {
<-gotHalf
@@ -177,16 +180,15 @@
checker.waitCall <- 1
}
}
+ errorCh <- nil
}()
// Server checks that client messages come in cleanly
i := 0
- err = nil
for ; i < N; i++ {
- var p []byte
- p, err = trS.readPacket()
- if err != nil {
- break
+ p, err := trS.readPacket()
+ if err != nil && err != io.EOF {
+ t.Fatalf("server error: %v", err)
}
if (i % 10) == 5 {
gotHalf <- 1
@@ -198,8 +200,8 @@
}
}
<-clientDone
- if err != nil && err != io.EOF {
- t.Fatalf("server error: %v", err)
+ if err := <-errorCh; err != nil {
+ t.Fatalf("sendPacket: %v", err)
}
if i != N {
t.Errorf("received %d messages, want 10.", i)
@@ -345,16 +347,16 @@
// While we read out the packet, a key change will be
// initiated.
- done := make(chan int, 1)
+ errorCh := make(chan error, 1)
go func() {
- defer close(done)
- if _, err := trC.readPacket(); err != nil {
- t.Fatalf("readPacket(client): %v", err)
- }
-
+ _, err := trC.readPacket()
+ errorCh <- err
}()
- <-done
+ if err := <-errorCh; err != nil {
+ t.Fatalf("readPacket(client): %v", err)
+ }
+
<-sync.called
}
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index 393017c..1db3be5 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -5,6 +5,8 @@
package ssh
import (
+ "errors"
+ "fmt"
"io"
"sync"
"testing"
@@ -29,14 +31,21 @@
go func() {
newCh, ok := <-s.incomingChannels
if !ok {
- t.Fatalf("No incoming channel")
+ t.Error("no incoming channel")
+ close(res)
+ return
}
if newCh.ChannelType() != "chan" {
- t.Fatalf("got type %q want chan", newCh.ChannelType())
+ t.Errorf("got type %q want chan", newCh.ChannelType())
+ newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType()))
+ close(res)
+ return
}
ch, _, err := newCh.Accept()
if err != nil {
- t.Fatalf("Accept %v", err)
+ t.Errorf("accept: %v", err)
+ close(res)
+ return
}
res <- ch.(*channel)
}()
@@ -45,8 +54,12 @@
if err != nil {
t.Fatalf("OpenChannel: %v", err)
}
+ w := <-res
+ if w == nil {
+ t.Fatal("unable to get write channel")
+ }
- return <-res, ch, c
+ return w, ch, c
}
// Test that stderr and stdout can be addressed from different
@@ -74,14 +87,14 @@
go func() {
c, err := io.ReadAll(reader)
if string(c) != magic {
- t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
+ t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
go func() {
c, err := io.ReadAll(reader.Stderr())
if string(c) != magic {
- t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
+ t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err)
}
rd.Done()
}()
@@ -102,11 +115,13 @@
go func() {
_, err := s.Write([]byte(magic))
if err != nil {
- t.Fatalf("Write: %v", err)
+ t.Errorf("Write: %v", err)
+ return
}
_, err = s.Extended(1).Write([]byte(magicExt))
if err != nil {
- t.Fatalf("Write: %v", err)
+ t.Errorf("Write: %v", err)
+ return
}
}()
@@ -215,10 +230,13 @@
go func() {
ch, ok := <-server.incomingChannels
if !ok {
- t.Fatalf("Accept")
+ t.Error("cannot accept channel")
+ return
}
if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
- t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+ t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
+ ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String())
+ return
}
ch.Reject(RejectionReason(42), "message")
}()
@@ -294,7 +312,7 @@
defer serverPipe.Close()
defer client.Close()
- kDone := make(chan struct{})
+ kDone := make(chan error, 1)
go func() {
// Ignore unknown channel messages that don't want a reply.
err := serverPipe.writePacket(Marshal(channelRequestMsg{
@@ -304,7 +322,8 @@
RequestSpecificData: []byte{},
}))
if err != nil {
- t.Fatalf("send: %v", err)
+ kDone <- fmt.Errorf("send: %w", err)
+ return
}
// Send a keepalive, which should get a channel failure message
@@ -316,44 +335,53 @@
RequestSpecificData: []byte{},
}))
if err != nil {
- t.Fatalf("send: %v", err)
+ kDone <- fmt.Errorf("send: %w", err)
+ return
}
packet, err := serverPipe.readPacket()
if err != nil {
- t.Fatalf("read packet: %v", err)
+ kDone <- fmt.Errorf("read packet: %w", err)
+ return
}
decoded, err := decode(packet)
if err != nil {
- t.Fatalf("decode failed: %v", err)
+ kDone <- fmt.Errorf("decode failed: %w", err)
+ return
}
switch msg := decoded.(type) {
case *channelRequestFailureMsg:
if msg.PeersID != 2 {
- t.Fatalf("received response to wrong message: %v", msg)
+ kDone <- fmt.Errorf("received response to wrong message: %v", msg)
+ return
+
}
default:
- t.Fatalf("unexpected channel message: %v", msg)
+ kDone <- fmt.Errorf("unexpected channel message: %v", msg)
+ return
}
- kDone <- struct{}{}
+ kDone <- nil
// Receive and respond to the keepalive to confirm the mux is
// still processing requests.
packet, err = serverPipe.readPacket()
if err != nil {
- t.Fatalf("read packet: %v", err)
+ kDone <- fmt.Errorf("read packet: %w", err)
+ return
}
if packet[0] != msgGlobalRequest {
- t.Fatalf("expected global request")
+ kDone <- errors.New("expected global request")
+ return
}
err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
Data: []byte{},
}))
if err != nil {
- t.Fatalf("failed to send failure msg: %v", err)
+ kDone <- fmt.Errorf("failed to send failure msg: %w", err)
+ return
}
close(kDone)
@@ -362,7 +390,10 @@
// Wait for the server to send the keepalive message and receive back a
// response.
select {
- case <-kDone:
+ case err := <-kDone:
+ if err != nil {
+ t.Fatal(err)
+ }
case <-time.After(10 * time.Second):
t.Fatalf("server never received ack")
}
@@ -373,7 +404,10 @@
}
select {
- case <-kDone:
+ case err := <-kDone:
+ if err != nil {
+ t.Fatal(err)
+ }
case <-time.After(10 * time.Second):
t.Fatalf("server never shut down")
}
@@ -385,20 +419,23 @@
defer serverPipe.Close()
defer client.Close()
- kDone := make(chan struct{})
+ kDone := make(chan error, 1)
go func() {
// Open the channel.
packet, err := serverPipe.readPacket()
if err != nil {
- t.Fatalf("read packet: %v", err)
+ kDone <- fmt.Errorf("read packet: %w", err)
+ return
}
if packet[0] != msgChannelOpen {
- t.Fatalf("expected chan open")
+ kDone <- errors.New("expected chan open")
+ return
}
var openMsg channelOpenMsg
if err := Unmarshal(packet, &openMsg); err != nil {
- t.Fatalf("unmarshal: %v", err)
+ kDone <- fmt.Errorf("unmarshal: %w", err)
+ return
}
// Send back the opened channel confirmation.
@@ -409,7 +446,8 @@
MaxPacketSize: channelMaxPacket,
}))
if err != nil {
- t.Fatalf("send: %v", err)
+ kDone <- fmt.Errorf("send: %w", err)
+ return
}
// Close the channel.
@@ -417,7 +455,8 @@
PeersID: openMsg.PeersID,
}))
if err != nil {
- t.Fatalf("send: %v", err)
+ kDone <- fmt.Errorf("send: %w", err)
+ return
}
// Send a keepalive message on the channel we just closed.
@@ -428,43 +467,51 @@
RequestSpecificData: []byte{},
}))
if err != nil {
- t.Fatalf("send: %v", err)
+ kDone <- fmt.Errorf("send: %w", err)
+ return
}
// Receive the channel closed response.
packet, err = serverPipe.readPacket()
if err != nil {
- t.Fatalf("read packet: %v", err)
+ kDone <- fmt.Errorf("read packet: %w", err)
+ return
}
if packet[0] != msgChannelClose {
- t.Fatalf("expected channel close")
+ kDone <- errors.New("expected channel close")
+ return
}
// Receive the keepalive response failure.
packet, err = serverPipe.readPacket()
if err != nil {
- t.Fatalf("read packet: %v", err)
+ kDone <- fmt.Errorf("read packet: %w", err)
+ return
}
if packet[0] != msgChannelFailure {
- t.Fatalf("expected channel close")
+ kDone <- errors.New("expected channel failure")
+ return
}
- kDone <- struct{}{}
+ kDone <- nil
// Receive and respond to the keepalive to confirm the mux is
// still processing requests.
packet, err = serverPipe.readPacket()
if err != nil {
- t.Fatalf("read packet: %v", err)
+ kDone <- fmt.Errorf("read packet: %w", err)
+ return
}
if packet[0] != msgGlobalRequest {
- t.Fatalf("expected global request")
+ kDone <- errors.New("expected global request")
+ return
}
err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{
Data: []byte{},
}))
if err != nil {
- t.Fatalf("failed to send failure msg: %v", err)
+ kDone <- fmt.Errorf("failed to send failure msg: %w", err)
+ return
}
close(kDone)
diff --git a/ssh/session_test.go b/ssh/session_test.go
index c4b9f0e..521677f 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -36,7 +36,8 @@
conn, chans, reqs, err := NewServerConn(c1, &conf)
if err != nil {
- t.Fatalf("Unable to handshake: %v", err)
+ t.Errorf("Unable to handshake: %v", err)
+ return
}
go DiscardRequests(reqs)
@@ -647,10 +648,12 @@
User: "user",
}
+ srvErrCh := make(chan error, 1)
go func() {
conn, chans, reqs, err := NewServerConn(c1, serverConf)
+ srvErrCh <- err
if err != nil {
- t.Fatalf("server handshake: %v", err)
+ return
}
serverID <- conn.SessionID()
go DiscardRequests(reqs)
@@ -659,10 +662,12 @@
}
}()
+ cliErrCh := make(chan error, 1)
go func() {
conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
+ cliErrCh <- err
if err != nil {
- t.Fatalf("client handshake: %v", err)
+ return
}
clientID <- conn.SessionID()
go DiscardRequests(reqs)
@@ -671,6 +676,14 @@
}
}()
+ if err := <-srvErrCh; err != nil {
+ t.Fatalf("server handshake: %v", err)
+ }
+
+ if err := <-cliErrCh; err != nil {
+ t.Fatalf("client handshake: %v", err)
+ }
+
s := <-serverID
c := <-clientID
if bytes.Compare(s, c) != 0 {
diff --git a/ssh/test/multi_auth_test.go b/ssh/test/multi_auth_test.go
index 6c253a7..403d736 100644
--- a/ssh/test/multi_auth_test.go
+++ b/ssh/test/multi_auth_test.go
@@ -77,27 +77,27 @@
func TestMultiAuth(t *testing.T) {
testCases := []multiAuthTestCase{
// Test password,publickey authentication, assert that password callback is called 1 time
- multiAuthTestCase{
+ {
authMethods: []string{"password", "publickey"},
expectedPasswordCbs: 1,
},
// Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time
- multiAuthTestCase{
+ {
authMethods: []string{"keyboard-interactive", "publickey"},
expectedKbdIntCbs: 1,
},
// Test publickey,password authentication, assert that password callback is called 1 time
- multiAuthTestCase{
+ {
authMethods: []string{"publickey", "password"},
expectedPasswordCbs: 1,
},
// Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time
- multiAuthTestCase{
+ {
authMethods: []string{"publickey", "keyboard-interactive"},
expectedKbdIntCbs: 1,
},
// Test password,password authentication, assert that password callback is called 2 times
- multiAuthTestCase{
+ {
authMethods: []string{"password", "password"},
expectedPasswordCbs: 2,
},