ssh: don't assume packet plaintext size

When reading GCM and ChaChaPoly1305 packets, don't make assumptions
about the size of the enciphered plaintext. This fixes two panics
caused by standards non-compliant malformed packets.

Thanks to Rod Hynes, Psiphon Inc. for reporting this issue.

Fixes golang/go#49932
Fixes CVE-2021-43565

Change-Id: I660cff39d197e0d04ec44d11d792b22d954df2ef
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1262659
Reviewed-by: Katie Hockman <katiehockman@google.com>
Reviewed-by: Julie Qiu <julieqiu@google.com>
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/368814
Trust: Roland Shoemaker <roland@golang.org>
Trust: Katie Hockman <katie@golang.org>
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Julie Qiu <julie@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
diff --git a/ssh/cipher.go b/ssh/cipher.go
index bddbde5..f8bdf49 100644
--- a/ssh/cipher.go
+++ b/ssh/cipher.go
@@ -394,6 +394,10 @@
 	}
 	c.incIV()
 
+	if len(plain) == 0 {
+		return nil, errors.New("ssh: empty packet")
+	}
+
 	padding := plain[0]
 	if padding < 4 {
 		// padding is a byte, so it automatically satisfies
@@ -710,6 +714,10 @@
 	plain := c.buf[4:contentEnd]
 	s.XORKeyStream(plain, plain)
 
+	if len(plain) == 0 {
+		return nil, errors.New("ssh: empty packet")
+	}
+
 	padding := plain[0]
 	if padding < 4 {
 		// padding is a byte, so it automatically satisfies
diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go
index 70a2b5b..6109828 100644
--- a/ssh/cipher_test.go
+++ b/ssh/cipher_test.go
@@ -8,7 +8,12 @@
 	"bytes"
 	"crypto"
 	"crypto/rand"
+	"encoding/binary"
+	"io"
 	"testing"
+
+	"golang.org/x/crypto/chacha20"
+	"golang.org/x/crypto/internal/poly1305"
 )
 
 func TestDefaultCiphersExist(t *testing.T) {
@@ -129,3 +134,98 @@
 		lastRead = bytesRead
 	}
 }
+
+func TestCVE202143565(t *testing.T) {
+	tests := []struct {
+		cipher          string
+		constructPacket func(packetCipher) io.Reader
+	}{
+		{
+			cipher: gcmCipherID,
+			constructPacket: func(client packetCipher) io.Reader {
+				internalCipher := client.(*gcmCipher)
+				b := &bytes.Buffer{}
+				prefix := [4]byte{}
+				if _, err := b.Write(prefix[:]); err != nil {
+					t.Fatal(err)
+				}
+				internalCipher.buf = internalCipher.aead.Seal(internalCipher.buf[:0], internalCipher.iv, []byte{}, prefix[:])
+				if _, err := b.Write(internalCipher.buf); err != nil {
+					t.Fatal(err)
+				}
+				internalCipher.incIV()
+
+				return b
+			},
+		},
+		{
+			cipher: chacha20Poly1305ID,
+			constructPacket: func(client packetCipher) io.Reader {
+				internalCipher := client.(*chacha20Poly1305Cipher)
+				b := &bytes.Buffer{}
+
+				nonce := make([]byte, 12)
+				s, err := chacha20.NewUnauthenticatedCipher(internalCipher.contentKey[:], nonce)
+				if err != nil {
+					t.Fatal(err)
+				}
+				var polyKey, discardBuf [32]byte
+				s.XORKeyStream(polyKey[:], polyKey[:])
+				s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
+
+				internalCipher.buf = make([]byte, 4+poly1305.TagSize)
+				binary.BigEndian.PutUint32(internalCipher.buf, 0)
+				ls, err := chacha20.NewUnauthenticatedCipher(internalCipher.lengthKey[:], nonce)
+				if err != nil {
+					t.Fatal(err)
+				}
+				ls.XORKeyStream(internalCipher.buf, internalCipher.buf[:4])
+				if _, err := io.ReadFull(rand.Reader, internalCipher.buf[4:4]); err != nil {
+					t.Fatal(err)
+				}
+
+				s.XORKeyStream(internalCipher.buf[4:], internalCipher.buf[4:4])
+
+				var tag [poly1305.TagSize]byte
+				poly1305.Sum(&tag, internalCipher.buf[:4], &polyKey)
+
+				copy(internalCipher.buf[4:], tag[:])
+
+				if _, err := b.Write(internalCipher.buf); err != nil {
+					t.Fatal(err)
+				}
+
+				return b
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		mac := "hmac-sha2-256"
+
+		kr := &kexResult{Hash: crypto.SHA1}
+		algs := directionAlgorithms{
+			Cipher:      tc.cipher,
+			MAC:         mac,
+			Compression: "none",
+		}
+		client, err := newPacketCipher(clientKeys, algs, kr)
+		if err != nil {
+			t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err)
+		}
+		server, err := newPacketCipher(clientKeys, algs, kr)
+		if err != nil {
+			t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err)
+		}
+
+		b := tc.constructPacket(client)
+
+		wantErr := "ssh: empty packet"
+		_, err = server.readCipherPacket(0, b)
+		if err == nil {
+			t.Fatalf("readCipherPacket(%q, %q): didn't fail with empty packet", tc.cipher, mac)
+		} else if err.Error() != wantErr {
+			t.Fatalf("readCipherPacket(%q, %q): unexpected error, got %q, want %q", tc.cipher, mac, err, wantErr)
+		}
+	}
+}