ssh: fix data race in dh group exchange sha256

Fixes golang/go#37607

Change-Id: Iedf6522ec9b9a676ac51c054407a6aef894885f5
GitHub-Last-Rev: 8cb2460c59d2e32bc3f0480bcd7867a113361c67
GitHub-Pull-Request: golang/crypto#126
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/222078
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/ssh/kex.go b/ssh/kex.go
index 6c3c648..7eedb20 100644
--- a/ssh/kex.go
+++ b/ssh/kex.go
@@ -572,7 +572,7 @@
 	return new(big.Int).Exp(theirPublic, myPrivate, gex.p), nil
 }
 
-func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
+func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
 	// Send GexRequest
 	kexDHGexRequest := kexDHGexRequestMsg{
 		MinBits:      dhGroupExchangeMinimumBits,
@@ -677,7 +677,7 @@
 // Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
 //
 // This is a minimal implementation to satisfy the automated tests.
-func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
+func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
 	// Receive GexRequest
 	packet, err := c.readPacket()
 	if err != nil {
diff --git a/ssh/kex_test.go b/ssh/kex_test.go
index 12ca0ac..1416b17 100644
--- a/ssh/kex_test.go
+++ b/ssh/kex_test.go
@@ -9,9 +9,14 @@
 import (
 	"crypto/rand"
 	"reflect"
+	"sync"
 	"testing"
 )
 
+// Runs multiple key exchanges concurrent to detect potential data races with
+// kex obtained from the global kexAlgoMap.
+// This test needs to be executed using the race detector in order to detect
+// race conditions.
 func TestKexes(t *testing.T) {
 	type kexResultErr struct {
 		result *kexResult
@@ -19,32 +24,42 @@
 	}
 
 	for name, kex := range kexAlgoMap {
-		a, b := memPipe()
+		t.Run(name, func(t *testing.T) {
+			wg := sync.WaitGroup{}
+			for i := 0; i < 3; i++ {
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					a, b := memPipe()
 
-		s := make(chan kexResultErr, 1)
-		c := make(chan kexResultErr, 1)
-		var magics handshakeMagics
-		go func() {
-			r, e := kex.Client(a, rand.Reader, &magics)
-			a.Close()
-			c <- kexResultErr{r, e}
-		}()
-		go func() {
-			r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
-			b.Close()
-			s <- kexResultErr{r, e}
-		}()
+					s := make(chan kexResultErr, 1)
+					c := make(chan kexResultErr, 1)
+					var magics handshakeMagics
+					go func() {
+						r, e := kex.Client(a, rand.Reader, &magics)
+						a.Close()
+						c <- kexResultErr{r, e}
+					}()
+					go func() {
+						r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
+						b.Close()
+						s <- kexResultErr{r, e}
+					}()
 
-		clientRes := <-c
-		serverRes := <-s
-		if clientRes.err != nil {
-			t.Errorf("client: %v", clientRes.err)
-		}
-		if serverRes.err != nil {
-			t.Errorf("server: %v", serverRes.err)
-		}
-		if !reflect.DeepEqual(clientRes.result, serverRes.result) {
-			t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
-		}
+					clientRes := <-c
+					serverRes := <-s
+					if clientRes.err != nil {
+						t.Errorf("client: %v", clientRes.err)
+					}
+					if serverRes.err != nil {
+						t.Errorf("server: %v", serverRes.err)
+					}
+					if !reflect.DeepEqual(clientRes.result, serverRes.result) {
+						t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result)
+					}
+				}()
+			}
+			wg.Wait()
+		})
 	}
 }