internal/wycheproof: add generic AEAD test

Add a generic AEAD test that exercises the vectors for AES GCM,
ChaCha20Poly-1305, and XChaCha20-Poly1305. Removes the existing
chacha20_poly1305_test.go test.

Change-Id: Icfaba30f8db2a1e32a9459c98cd3af5d63052027
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/234688
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
diff --git a/internal/wycheproof/aead_test.go b/internal/wycheproof/aead_test.go
new file mode 100644
index 0000000..292d854
--- /dev/null
+++ b/internal/wycheproof/aead_test.go
@@ -0,0 +1,176 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wycheproof
+
+import (
+	"bytes"
+	"crypto/aes"
+	"crypto/cipher"
+	"fmt"
+	"testing"
+
+	"golang.org/x/crypto/chacha20poly1305"
+)
+
+func TestAEAD(t *testing.T) {
+	// AeadTestVector
+	type AeadTestVector struct {
+
+		// additional authenticated data
+		Aad string `json:"aad,omitempty"`
+
+		// A brief description of the test case
+		Comment string `json:"comment,omitempty"`
+
+		// the ciphertext (without iv and tag)
+		Ct string `json:"ct,omitempty"`
+
+		// A list of flags
+		Flags []string `json:"flags,omitempty"`
+
+		// the nonce
+		Iv string `json:"iv,omitempty"`
+
+		// the key
+		Key string `json:"key,omitempty"`
+
+		// the plaintext
+		Msg string `json:"msg,omitempty"`
+
+		// Test result
+		Result string `json:"result,omitempty"`
+
+		// the authentication tag
+		Tag string `json:"tag,omitempty"`
+
+		// Identifier of the test case
+		TcId int `json:"tcId,omitempty"`
+	}
+
+	// Notes a description of the labels used in the test vectors
+	type Notes struct {
+	}
+
+	// AeadTestGroup
+	type AeadTestGroup struct {
+
+		// the IV size in bits
+		IvSize int `json:"ivSize,omitempty"`
+
+		// the keySize in bits
+		KeySize int `json:"keySize,omitempty"`
+
+		// the expected size of the tag in bits
+		TagSize int               `json:"tagSize,omitempty"`
+		Tests   []*AeadTestVector `json:"tests,omitempty"`
+		Type    interface{}       `json:"type,omitempty"`
+	}
+
+	// Root
+	type Root struct {
+
+		// the primitive tested in the test file
+		Algorithm string `json:"algorithm,omitempty"`
+
+		// the version of the test vectors.
+		GeneratorVersion string `json:"generatorVersion,omitempty"`
+
+		// additional documentation
+		Header []string `json:"header,omitempty"`
+
+		// a description of the labels used in the test vectors
+		Notes *Notes `json:"notes,omitempty"`
+
+		// the number of test vectors in this test
+		NumberOfTests int              `json:"numberOfTests,omitempty"`
+		Schema        interface{}      `json:"schema,omitempty"`
+		TestGroups    []*AeadTestGroup `json:"testGroups,omitempty"`
+	}
+
+	testSealOpen := func(t *testing.T, aead cipher.AEAD, tv *AeadTestVector, recoverBadNonce func()) {
+		defer recoverBadNonce()
+
+		iv, tag, ct, msg, aad := decodeHex(tv.Iv), decodeHex(tv.Tag), decodeHex(tv.Ct), decodeHex(tv.Msg), decodeHex(tv.Aad)
+
+		genCT := aead.Seal(nil, iv, msg, aad)
+		genMsg, err := aead.Open(nil, iv, genCT, aad)
+		if err != nil {
+			t.Errorf("failed to decrypt generated ciphertext: %s", err)
+		}
+		if !bytes.Equal(genMsg, msg) {
+			t.Errorf("unexpected roundtripped plaintext: got %x, want %x", genMsg, msg)
+		}
+
+		ctWithTag := append(ct, tag...)
+		msg2, err := aead.Open(nil, iv, ctWithTag, aad)
+		wantPass := shouldPass(tv.Result, tv.Flags, nil)
+		if !wantPass && err == nil {
+			t.Error("decryption succeeded when it should've failed")
+		} else if wantPass {
+			if err != nil {
+				t.Fatalf("decryption failed: %s", err)
+			}
+			if !bytes.Equal(genCT, ctWithTag) {
+				t.Errorf("generated ciphertext doesn't match expected: got %x, want %x", genCT, ctWithTag)
+			}
+			if !bytes.Equal(msg, msg2) {
+				t.Errorf("decrypted ciphertext doesn't match expected: got %x, want %x", msg2, msg)
+			}
+		}
+	}
+
+	vectors := map[string]func(*testing.T, []byte) cipher.AEAD{
+		"aes_gcm_test.json": func(t *testing.T, key []byte) cipher.AEAD {
+			aesCipher, err := aes.NewCipher(key)
+			if err != nil {
+				t.Fatalf("failed to construct cipher: %s", err)
+			}
+			aead, err := cipher.NewGCM(aesCipher)
+			if err != nil {
+				t.Fatalf("failed to construct cipher: %s", err)
+			}
+			return aead
+		},
+		"chacha20_poly1305_test.json": func(t *testing.T, key []byte) cipher.AEAD {
+			aead, err := chacha20poly1305.New(key)
+			if err != nil {
+				t.Fatalf("failed to construct cipher: %s", err)
+			}
+			return aead
+		},
+		"xchacha20_poly1305_test.json": func(t *testing.T, key []byte) cipher.AEAD {
+			aead, err := chacha20poly1305.NewX(key)
+			if err != nil {
+				t.Fatalf("failed to construct cipher: %s", err)
+			}
+			return aead
+		},
+	}
+	for file, cipherInit := range vectors {
+		var root Root
+		readTestVector(t, file, &root)
+		for _, tg := range root.TestGroups {
+			for _, tv := range tg.Tests {
+				testName := fmt.Sprintf("%s #%d", file, tv.TcId)
+				if tv.Comment != "" {
+					testName += fmt.Sprintf(" %s", tv.Comment)
+				}
+				t.Run(testName, func(t *testing.T) {
+					aead := cipherInit(t, decodeHex(tv.Key))
+					testSealOpen(t, aead, tv, func() {
+						// A bad nonce causes a panic in AEAD.Seal and AEAD.Open,
+						// so should be recovered. Fail the test if it broke for
+						// some other reason.
+						if r := recover(); r != nil {
+							if tg.IvSize/8 == aead.NonceSize() {
+								t.Error("unexpected panic")
+							}
+						}
+					})
+				})
+			}
+		}
+	}
+}
diff --git a/internal/wycheproof/chacha20_poly1305_test.go b/internal/wycheproof/chacha20_poly1305_test.go
deleted file mode 100644
index 7fedbc4..0000000
--- a/internal/wycheproof/chacha20_poly1305_test.go
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package wycheproof
-
-import (
-	"crypto/cipher"
-	"encoding/hex"
-	"testing"
-
-	"golang.org/x/crypto/chacha20poly1305"
-)
-
-func TestChaCha20Poly1305(t *testing.T) {
-	// AeadTestVector
-	type AeadTestVector struct {
-
-		// additional authenticated data
-		Aad string `json:"aad,omitempty"`
-
-		// A brief description of the test case
-		Comment string `json:"comment,omitempty"`
-
-		// the ciphertext (without iv and tag)
-		Ct string `json:"ct,omitempty"`
-
-		// A list of flags
-		Flags []string `json:"flags,omitempty"`
-
-		// the nonce
-		Iv string `json:"iv,omitempty"`
-
-		// the key
-		Key string `json:"key,omitempty"`
-
-		// the plaintext
-		Msg string `json:"msg,omitempty"`
-
-		// Test result
-		Result string `json:"result,omitempty"`
-
-		// the authentication tag
-		Tag string `json:"tag,omitempty"`
-
-		// Identifier of the test case
-		TcId int `json:"tcId,omitempty"`
-	}
-
-	// Notes a description of the labels used in the test vectors
-	type Notes struct {
-	}
-
-	// AeadTestGroup
-	type AeadTestGroup struct {
-
-		// the IV size in bits
-		IvSize int `json:"ivSize,omitempty"`
-
-		// the keySize in bits
-		KeySize int `json:"keySize,omitempty"`
-
-		// the expected size of the tag in bits
-		TagSize int               `json:"tagSize,omitempty"`
-		Tests   []*AeadTestVector `json:"tests,omitempty"`
-		Type    interface{}       `json:"type,omitempty"`
-	}
-
-	// Root
-	type Root struct {
-
-		// the primitive tested in the test file
-		Algorithm string `json:"algorithm,omitempty"`
-
-		// the version of the test vectors.
-		GeneratorVersion string `json:"generatorVersion,omitempty"`
-
-		// additional documentation
-		Header []string `json:"header,omitempty"`
-
-		// a description of the labels used in the test vectors
-		Notes *Notes `json:"notes,omitempty"`
-
-		// the number of test vectors in this test
-		NumberOfTests int              `json:"numberOfTests,omitempty"`
-		Schema        interface{}      `json:"schema,omitempty"`
-		TestGroups    []*AeadTestGroup `json:"testGroups,omitempty"`
-	}
-
-	testAeadSealOpen := func(t *testing.T, aead cipher.AEAD, tv *AeadTestVector, recoverBadNonce func()) {
-		defer recoverBadNonce()
-
-		// Encrypt the message, then decrypt the new ciphertext and validate
-		// the decrypted message.
-		ciphertext := aead.Seal(nil, decodeHex(tv.Iv), decodeHex(tv.Msg), decodeHex(tv.Aad))
-		msg, err := aead.Open(nil, decodeHex(tv.Iv), ciphertext, decodeHex(tv.Aad))
-		if err != nil {
-			t.Fatalf("#%d: decryption failed: %v", tv.TcId, err)
-		}
-		if got, want := hex.EncodeToString(msg), tv.Msg; got != want {
-			t.Errorf("#%d: bad message after encrypting and decrypting: %s, want %v", tv.TcId, got, want)
-		}
-
-		// Decrypt the provided ciphertext and validate the decrypted message.
-		tv.Ct += tv.Tag // append the tag to the ciphertext
-		msg2, err := aead.Open(nil, decodeHex(tv.Iv), decodeHex(tv.Ct), decodeHex(tv.Aad))
-		wantPass := shouldPass(tv.Result, tv.Flags, nil)
-		if wantPass {
-			if err != nil {
-				t.Errorf("#%d, type: %s, comment: %q, decryption wanted success, got err: %v", tv.TcId, tv.Result, tv.Comment, err)
-			}
-			if got, want := hex.EncodeToString(ciphertext), tv.Ct; got != want {
-				t.Errorf("#%d: ciphertext doesn't match: %s, want=%s", tv.TcId, got, want)
-			}
-			if got, want := hex.EncodeToString(msg2), tv.Msg; got != want {
-				t.Errorf("#%d: bad message after decrypting ciphertext: %s, want %v", tv.TcId, got, want)
-			}
-		} else {
-			if err == nil {
-				t.Errorf("#%d, type: %s, comment: %q, decryption wanted error", tv.TcId, tv.Result, tv.Comment)
-			}
-		}
-	}
-
-	var root Root
-	readTestVector(t, "chacha20_poly1305_test.json", &root)
-	for _, tg := range root.TestGroups {
-		for _, tv := range tg.Tests {
-			aead, err := chacha20poly1305.New(decodeHex(tv.Key))
-			if err != nil {
-				t.Fatalf("#%d: %v", tv.TcId, err)
-			}
-			if tg.TagSize/8 != aead.Overhead() {
-				t.Fatalf("#%d: bad tag length", tv.TcId)
-			}
-			testAeadSealOpen(t, aead, tv, func() {
-				// A bad nonce causes a panic in AEAD.Seal and AEAD.Open,
-				// so should be recovered. Fail the test if it broke for
-				// some other reason.
-				if r := recover(); r != nil {
-					if tg.IvSize/8 == chacha20poly1305.NonceSize {
-						t.Errorf("#%d: unexpected panic", tv.TcId)
-					}
-				}
-			})
-		}
-	}
-}