blob: 23f729dd94bdfd866f8689e98caf930e1effe3ed [file] [log] [blame]
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"rand"
"reflect"
"testing"
"testing/quick"
)
var tests = []interface{}{
&clientHelloMsg{},
&serverHelloMsg{},
&certificateMsg{},
&certificateRequestMsg{},
&certificateVerifyMsg{},
&certificateStatusMsg{},
&clientKeyExchangeMsg{},
&finishedMsg{},
&nextProtoMsg{},
}
type testMessage interface {
marshal() []byte
unmarshal([]byte) bool
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
for i, iface := range tests {
ty := reflect.ValueOf(iface).Type()
n := 100
if testing.Short() {
n = 5
}
for j := 0; j < n; j++ {
v, ok := quick.Value(ty, rand)
if !ok {
t.Errorf("#%d: failed to create value", i)
break
}
m1 := v.Interface().(testMessage)
marshaled := m1.marshal()
m2 := iface.(testMessage)
if !m2.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
break
}
m2.marshal() // to fill any marshal cache in the message
if !reflect.DeepEqual(m1, m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break
}
if i >= 2 {
// The first two message types (ClientHello and
// ServerHello) are allowed to have parsable
// prefixes because the extension data is
// optional.
for j := 0; j < len(marshaled); j++ {
if m2.unmarshal(marshaled[0:j]) {
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
break
}
}
}
}
}
}
func TestFuzz(t *testing.T) {
rand := rand.New(rand.NewSource(0))
for _, iface := range tests {
m := iface.(testMessage)
for j := 0; j < 1000; j++ {
len := rand.Intn(100)
bytes := randomBytes(len, rand)
// This just looks for crashes due to bounds errors etc.
m.unmarshal(bytes)
}
}
}
func randomBytes(n int, rand *rand.Rand) []byte {
r := make([]byte, n)
for i := 0; i < n; i++ {
r[i] = byte(rand.Int31())
}
return r
}
func randomString(n int, rand *rand.Rand) string {
b := randomBytes(n, rand)
return string(b)
}
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientHelloMsg{}
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuites = make([]uint16, rand.Intn(63)+1)
for i := 0; i < len(m.cipherSuites); i++ {
m.cipherSuites[i] = uint16(rand.Int31())
}
m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
if rand.Intn(10) > 5 {
m.nextProtoNeg = true
}
if rand.Intn(10) > 5 {
m.serverName = randomString(rand.Intn(255), rand)
}
m.ocspStapling = rand.Intn(10) > 5
m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
m.supportedCurves = make([]uint16, rand.Intn(5)+1)
for i := range m.supportedCurves {
m.supportedCurves[i] = uint16(rand.Intn(30000))
}
return reflect.ValueOf(m)
}
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &serverHelloMsg{}
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.sessionId = randomBytes(rand.Intn(32), rand)
m.cipherSuite = uint16(rand.Int31())
m.compressionMethod = uint8(rand.Intn(256))
if rand.Intn(10) > 5 {
m.nextProtoNeg = true
n := rand.Intn(10)
m.nextProtos = make([]string, n)
for i := 0; i < n; i++ {
m.nextProtos[i] = randomString(20, rand)
}
}
return reflect.ValueOf(m)
}
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsg{}
numCerts := rand.Intn(20)
m.certificates = make([][]byte, numCerts)
for i := 0; i < numCerts; i++ {
m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
}
return reflect.ValueOf(m)
}
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateRequestMsg{}
m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
numCAs := rand.Intn(100)
m.certificateAuthorities = make([][]byte, numCAs)
for i := 0; i < numCAs; i++ {
m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
}
return reflect.ValueOf(m)
}
func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateVerifyMsg{}
m.signature = randomBytes(rand.Intn(15)+1, rand)
return reflect.ValueOf(m)
}
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateStatusMsg{}
if rand.Intn(10) > 5 {
m.statusType = statusTypeOCSP
m.response = randomBytes(rand.Intn(10)+1, rand)
} else {
m.statusType = 42
}
return reflect.ValueOf(m)
}
func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &clientKeyExchangeMsg{}
m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
return reflect.ValueOf(m)
}
func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &finishedMsg{}
m.verifyData = randomBytes(12, rand)
return reflect.ValueOf(m)
}
func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &nextProtoMsg{}
m.proto = randomString(rand.Intn(255), rand)
return reflect.ValueOf(m)
}