ssh: fix data race in dh group exchange sha256
Fixes golang/go#37607
Change-Id: Iedf6522ec9b9a676ac51c054407a6aef894885f5
GitHub-Last-Rev: 8cb2460c59d2e32bc3f0480bcd7867a113361c67
GitHub-Pull-Request: golang/crypto#126
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/222078
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/kex.go b/ssh/kex.go
index 6c3c648..7eedb20 100644
--- a/ssh/kex.go
+++ b/ssh/kex.go
@@ -572,7 +572,7 @@
return new(big.Int).Exp(theirPublic, myPrivate, gex.p), nil
}
-func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
+func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
// Send GexRequest
kexDHGexRequest := kexDHGexRequestMsg{
MinBits: dhGroupExchangeMinimumBits,
@@ -677,7 +677,7 @@
// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
//
// This is a minimal implementation to satisfy the automated tests.
-func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
// Receive GexRequest
packet, err := c.readPacket()
if err != nil {
diff --git a/ssh/kex_test.go b/ssh/kex_test.go
index 12ca0ac..1416b17 100644
--- a/ssh/kex_test.go
+++ b/ssh/kex_test.go
@@ -9,9 +9,14 @@
import (
"crypto/rand"
"reflect"
+ "sync"
"testing"
)
+// Runs multiple key exchanges concurrent to detect potential data races with
+// kex obtained from the global kexAlgoMap.
+// This test needs to be executed using the race detector in order to detect
+// race conditions.
func TestKexes(t *testing.T) {
type kexResultErr struct {
result *kexResult
@@ -19,32 +24,42 @@
}
for name, kex := range kexAlgoMap {
- a, b := memPipe()
+ t.Run(name, func(t *testing.T) {
+ wg := sync.WaitGroup{}
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ a, b := memPipe()
- s := make(chan kexResultErr, 1)
- c := make(chan kexResultErr, 1)
- var magics handshakeMagics
- go func() {
- r, e := kex.Client(a, rand.Reader, &magics)
- a.Close()
- c <- kexResultErr{r, e}
- }()
- go func() {
- r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
- b.Close()
- s <- kexResultErr{r, e}
- }()
+ s := make(chan kexResultErr, 1)
+ c := make(chan kexResultErr, 1)
+ var magics handshakeMagics
+ go func() {
+ r, e := kex.Client(a, rand.Reader, &magics)
+ a.Close()
+ c <- kexResultErr{r, e}
+ }()
+ go func() {
+ r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
+ b.Close()
+ s <- kexResultErr{r, e}
+ }()
- clientRes := <-c
- serverRes := <-s
- if clientRes.err != nil {
- t.Errorf("client: %v", clientRes.err)
- }
- if serverRes.err != nil {
- t.Errorf("server: %v", serverRes.err)
- }
- if !reflect.DeepEqual(clientRes.result, serverRes.result) {
- t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
- }
+ clientRes := <-c
+ serverRes := <-s
+ if clientRes.err != nil {
+ t.Errorf("client: %v", clientRes.err)
+ }
+ if serverRes.err != nil {
+ t.Errorf("server: %v", serverRes.err)
+ }
+ if !reflect.DeepEqual(clientRes.result, serverRes.result) {
+ t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
+ }
+ }()
+ }
+ wg.Wait()
+ })
}
}