internal/wycheproof: add ECDH tests, including point decompression

Fixes golang/go#38936

Change-Id: I231d30fcc683abd9efb36b6fd9cc05f599078ade
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/396174
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <valsorda@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
diff --git a/internal/wycheproof/ecdh_test.go b/internal/wycheproof/ecdh_test.go
new file mode 100644
index 0000000..a3918ba
--- /dev/null
+++ b/internal/wycheproof/ecdh_test.go
@@ -0,0 +1,163 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wycheproof
+
+import (
+	"bytes"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/x509"
+	"encoding/asn1"
+	"errors"
+	"fmt"
+	"testing"
+
+	"golang.org/x/crypto/cryptobyte"
+	casn1 "golang.org/x/crypto/cryptobyte/asn1"
+)
+
+func TestECDH(t *testing.T) {
+	type ECDHTestVector struct {
+		// A brief description of the test case
+		Comment string `json:"comment,omitempty"`
+		// A list of flags
+		Flags []string `json:"flags,omitempty"`
+		// the private key
+		Private string `json:"private,omitempty"`
+		// Encoded public key
+		Public string `json:"public,omitempty"`
+		// Test result
+		Result string `json:"result,omitempty"`
+		// The shared secret key
+		Shared string `json:"shared,omitempty"`
+		// Identifier of the test case
+		TcID int `json:"tcId,omitempty"`
+	}
+
+	type ECDHTestGroup struct {
+		Curve string            `json:"curve,omitempty"`
+		Tests []*ECDHTestVector `json:"tests,omitempty"`
+	}
+
+	type Root struct {
+		TestGroups []*ECDHTestGroup `json:"testGroups,omitempty"`
+	}
+
+	flagsShouldPass := map[string]bool{
+		// ParsePKIXPublicKey doesn't support compressed points, but we test
+		// them against UnmarshalCompressed anyway.
+		"CompressedPoint": true,
+		// We don't support decoding custom curves.
+		"UnnamedCurve": false,
+		// WrongOrder and UnusedParam are only found with UnnamedCurve.
+		"WrongOrder":  false,
+		"UnusedParam": false,
+	}
+
+	// supportedCurves is a map of all elliptic curves supported
+	// by crypto/elliptic, which can subsequently be parsed and tested.
+	supportedCurves := map[string]bool{
+		"secp224r1": true,
+		"secp256r1": true,
+		"secp384r1": true,
+		"secp521r1": true,
+	}
+
+	var root Root
+	readTestVector(t, "ecdh_test.json", &root)
+	for _, tg := range root.TestGroups {
+		if !supportedCurves[tg.Curve] {
+			continue
+		}
+		for _, tt := range tg.Tests {
+			tg, tt := tg, tt
+			t.Run(fmt.Sprintf("%s/%d", tg.Curve, tt.TcID), func(t *testing.T) {
+				t.Logf("Type: %v", tt.Result)
+				t.Logf("Flags: %q", tt.Flags)
+				t.Log(tt.Comment)
+
+				shouldPass := shouldPass(tt.Result, tt.Flags, flagsShouldPass)
+
+				p := decodeHex(tt.Public)
+				pp, err := x509.ParsePKIXPublicKey(p)
+				if err != nil {
+					pp, err = decodeCompressedPKIX(p)
+				}
+				if err != nil {
+					if shouldPass {
+						t.Errorf("unexpected parsing error: %s", err)
+					}
+					return
+				}
+				pub := pp.(*ecdsa.PublicKey)
+
+				priv := decodeHex(tt.Private)
+				shared := decodeHex(tt.Shared)
+
+				x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, priv)
+				xBytes := make([]byte, (pub.Curve.Params().BitSize+7)/8)
+				got := bytes.Equal(shared, x.FillBytes(xBytes))
+
+				if want := shouldPass; got != want {
+					t.Errorf("wanted success %v, got %v", want, got)
+				}
+			})
+		}
+	}
+}
+
+func decodeCompressedPKIX(der []byte) (interface{}, error) {
+	s := cryptobyte.String(der)
+	var s1, s2 cryptobyte.String
+	var algoOID, namedCurveOID asn1.ObjectIdentifier
+	var pointDER []byte
+	if !s.ReadASN1(&s1, casn1.SEQUENCE) || !s.Empty() ||
+		!s1.ReadASN1(&s2, casn1.SEQUENCE) ||
+		!s2.ReadASN1ObjectIdentifier(&algoOID) ||
+		!s2.ReadASN1ObjectIdentifier(&namedCurveOID) || !s2.Empty() ||
+		!s1.ReadASN1BitStringAsBytes(&pointDER) || !s1.Empty() {
+		return nil, errors.New("failed to parse PKIX structure")
+	}
+
+	if !algoOID.Equal(oidPublicKeyECDSA) {
+		return nil, errors.New("wrong algorithm OID")
+	}
+	namedCurve := namedCurveFromOID(namedCurveOID)
+	if namedCurve == nil {
+		return nil, errors.New("unsupported elliptic curve")
+	}
+	x, y := elliptic.UnmarshalCompressed(namedCurve, pointDER)
+	if x == nil {
+		return nil, errors.New("failed to unmarshal elliptic curve point")
+	}
+	pub := &ecdsa.PublicKey{
+		Curve: namedCurve,
+		X:     x,
+		Y:     y,
+	}
+	return pub, nil
+}
+
+var (
+	oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}
+	oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33}
+	oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7}
+	oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34}
+	oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35}
+)
+
+func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve {
+	switch {
+	case oid.Equal(oidNamedCurveP224):
+		return elliptic.P224()
+	case oid.Equal(oidNamedCurveP256):
+		return elliptic.P256()
+	case oid.Equal(oidNamedCurveP384):
+		return elliptic.P384()
+	case oid.Equal(oidNamedCurveP521):
+		return elliptic.P521()
+	}
+	return nil
+}
diff --git a/internal/wycheproof/ecdsa_compat_test.go b/internal/wycheproof/ecdsa_compat_test.go
deleted file mode 100644
index 5880fb3..0000000
--- a/internal/wycheproof/ecdsa_compat_test.go
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2020 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.15
-// +build !go1.15
-
-// ecdsa.VerifyASN1 was added in Go 1.15.
-
-package wycheproof
-
-import (
-	"crypto/ecdsa"
-	"math/big"
-
-	"golang.org/x/crypto/cryptobyte"
-	"golang.org/x/crypto/cryptobyte/asn1"
-)
-
-func verifyASN1(pub *ecdsa.PublicKey, hash, sig []byte) bool {
-	var (
-		r, s  = &big.Int{}, &big.Int{}
-		inner cryptobyte.String
-	)
-	input := cryptobyte.String(sig)
-	if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
-		!input.Empty() ||
-		!inner.ReadASN1Integer(r) ||
-		!inner.ReadASN1Integer(s) ||
-		!inner.Empty() {
-		return false
-	}
-	return ecdsa.Verify(pub, hash, r, s)
-}
diff --git a/internal/wycheproof/ecdsa_go115_test.go b/internal/wycheproof/ecdsa_go115_test.go
deleted file mode 100644
index e13e709..0000000
--- a/internal/wycheproof/ecdsa_go115_test.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright 2020 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.15
-// +build go1.15
-
-package wycheproof
-
-import (
-	"crypto/ecdsa"
-)
-
-func verifyASN1(pub *ecdsa.PublicKey, hash, sig []byte) bool {
-	return ecdsa.VerifyASN1(pub, hash, sig)
-}
diff --git a/internal/wycheproof/ecdsa_test.go b/internal/wycheproof/ecdsa_test.go
index 81731f7..42f3285 100644
--- a/internal/wycheproof/ecdsa_test.go
+++ b/internal/wycheproof/ecdsa_test.go
@@ -9,126 +9,47 @@
 	"testing"
 )
 
-func TestEcdsa(t *testing.T) {
-	// AsnSignatureTestVector
-	type AsnSignatureTestVector struct {
-
+func TestECDSA(t *testing.T) {
+	type ASNSignatureTestVector struct {
 		// A brief description of the test case
-		Comment string `json:"comment,omitempty"`
-
+		Comment string `json:"comment"`
 		// A list of flags
-		Flags []string `json:"flags,omitempty"`
-
+		Flags []string `json:"flags"`
 		// The message to sign
-		Msg string `json:"msg,omitempty"`
-
+		Msg string `json:"msg"`
 		// Test result
-		Result string `json:"result,omitempty"`
-
-		// An ASN encoded signature for msg
-		Sig string `json:"sig,omitempty"`
-
+		Result string `json:"result"`
+		// An ASN.1 encoded signature for msg
+		Sig string `json:"sig"`
 		// Identifier of the test case
-		TcId int `json:"tcId,omitempty"`
+		TcID int `json:"tcId"`
 	}
 
-	// EcPublicKey
-	type EcPublicKey struct {
-
-		// the EC group used by this public key
-		Curve interface{} `json:"curve,omitempty"`
-
-		// the key size in bits
-		KeySize int `json:"keySize,omitempty"`
-
-		// the key type
-		Type string `json:"type,omitempty"`
-
-		// encoded public key point
-		Uncompressed string `json:"uncompressed,omitempty"`
-
-		// the x-coordinate of the public key point
-		Wx string `json:"wx,omitempty"`
-
-		// the y-coordinate of the public key point
-		Wy string `json:"wy,omitempty"`
+	type ECPublicKey struct {
+		// The EC group used by this public key
+		Curve interface{} `json:"curve"`
 	}
 
-	// EcUnnamedGroup
-	type EcUnnamedGroup struct {
-
-		// coefficient a of the elliptic curve equation
-		A string `json:"a,omitempty"`
-
-		// coefficient b of the elliptic curve equation
-		B string `json:"b,omitempty"`
-
-		// the x-coordinate of the generator
-		Gx string `json:"gx,omitempty"`
-
-		// the y-coordinate of the generator
-		Gy string `json:"gy,omitempty"`
-
-		// the cofactor
-		H int `json:"h,omitempty"`
-
-		// the order of the generator
-		N string `json:"n,omitempty"`
-
-		// the order of the underlying field
-		P string `json:"p,omitempty"`
-
-		// an unnamed EC group over a prime field in Weierstrass form
-		Type string `json:"type,omitempty"`
-	}
-
-	// EcdsaTestGroup
-	type EcdsaTestGroup struct {
-
-		// unenocded EC public key
-		Key *EcPublicKey `json:"key,omitempty"`
-
+	type ECDSATestGroup struct {
+		// Unencoded EC public key
+		Key *ECPublicKey `json:"key"`
 		// DER encoded public key
-		KeyDer string `json:"keyDer,omitempty"`
-
-		// Pem encoded public key
-		KeyPem string `json:"keyPem,omitempty"`
-
+		KeyDER string `json:"keyDer"`
 		// the hash function used for ECDSA
-		Sha   string                    `json:"sha,omitempty"`
-		Tests []*AsnSignatureTestVector `json:"tests,omitempty"`
-		Type  interface{}               `json:"type,omitempty"`
+		SHA   string                    `json:"sha"`
+		Tests []*ASNSignatureTestVector `json:"tests"`
 	}
 
-	// Notes a description of the labels used in the test vectors
-	type Notes struct {
-	}
-
-	// Root
 	type Root struct {
-
-		// the primitive tested in the test file
-		Algorithm string `json:"algorithm,omitempty"`
-
-		// the version of the test vectors.
-		GeneratorVersion string `json:"generatorVersion,omitempty"`
-
-		// additional documentation
-		Header []string `json:"header,omitempty"`
-
-		// a description of the labels used in the test vectors
-		Notes *Notes `json:"notes,omitempty"`
-
-		// the number of test vectors in this test
-		NumberOfTests int               `json:"numberOfTests,omitempty"`
-		Schema        interface{}       `json:"schema,omitempty"`
-		TestGroups    []*EcdsaTestGroup `json:"testGroups,omitempty"`
+		TestGroups []*ECDSATestGroup `json:"testGroups"`
 	}
 
 	flagsShouldPass := map[string]bool{
-		// An encoded ASN.1 integer missing a leading zero is invalid, but accepted by some implementations.
+		// An encoded ASN.1 integer missing a leading zero is invalid, but
+		// accepted by some implementations.
 		"MissingZero": false,
-		// A signature using a weaker hash than the EC params is not a security risk, as long as the hash is secure.
+		// A signature using a weaker hash than the EC params is not a security
+		// risk, as long as the hash is secure.
 		// https://www.imperialviolet.org/2014/05/25/strengthmatching.html
 		"WeakHash": true,
 	}
@@ -149,15 +70,15 @@
 		if !supportedCurves[curve] {
 			continue
 		}
-		pub := decodePublicKey(tg.KeyDer).(*ecdsa.PublicKey)
-		h := parseHash(tg.Sha).New()
+		pub := decodePublicKey(tg.KeyDER).(*ecdsa.PublicKey)
+		h := parseHash(tg.SHA).New()
 		for _, sig := range tg.Tests {
 			h.Reset()
 			h.Write(decodeHex(sig.Msg))
 			hashed := h.Sum(nil)
-			got := verifyASN1(pub, hashed, decodeHex(sig.Sig))
+			got := ecdsa.VerifyASN1(pub, hashed, decodeHex(sig.Sig))
 			if want := shouldPass(sig.Result, sig.Flags, flagsShouldPass); got != want {
-				t.Errorf("tcid: %d, type: %s, comment: %q, wanted success: %t", sig.TcId, sig.Result, sig.Comment, want)
+				t.Errorf("tcid: %d, type: %s, comment: %q, wanted success: %t", sig.TcID, sig.Result, sig.Comment, want)
 			}
 		}
 	}