|  | // Copyright 2009 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package subtle | 
|  |  | 
|  | import ( | 
|  | "testing" | 
|  | "testing/quick" | 
|  | ) | 
|  |  | 
// TestConstantTimeCompareStruct is one table entry for
// TestConstantTimeCompare: two byte slices to compare and the integer
// result the comparison is expected to produce.
type TestConstantTimeCompareStruct struct {
	a   []byte // first slice passed to ConstantTimeCompare
	b   []byte // second slice passed to ConstantTimeCompare
	out int    // expected return value
}
|  |  | 
|  | var testConstantTimeCompareData = []TestConstantTimeCompareStruct{ | 
|  | {[]byte{}, []byte{}, 1}, | 
|  | {[]byte{0x11}, []byte{0x11}, 1}, | 
|  | {[]byte{0x12}, []byte{0x11}, 0}, | 
|  | {[]byte{0x11}, []byte{0x11, 0x12}, 0}, | 
|  | {[]byte{0x11, 0x12}, []byte{0x11}, 0}, | 
|  | } | 
|  |  | 
|  | func TestConstantTimeCompare(t *testing.T) { | 
|  | for i, test := range testConstantTimeCompareData { | 
|  | if r := ConstantTimeCompare(test.a, test.b); r != test.out { | 
|  | t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out) | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
// TestConstantTimeByteEqStruct is one table entry for
// TestConstantTimeByteEq: two bytes to compare and the integer result the
// comparison is expected to produce.
type TestConstantTimeByteEqStruct struct {
	a   uint8 // first byte passed to ConstantTimeByteEq
	b   uint8 // second byte passed to ConstantTimeByteEq
	out int   // expected return value
}
|  |  | 
|  | var testConstandTimeByteEqData = []TestConstantTimeByteEqStruct{ | 
|  | {0, 0, 1}, | 
|  | {0, 1, 0}, | 
|  | {1, 0, 0}, | 
|  | {0xff, 0xff, 1}, | 
|  | {0xff, 0xfe, 0}, | 
|  | } | 
|  |  | 
// byteEq is the plain, branching reference for byte equality: it returns
// 1 when a and b are the same value and 0 otherwise. TestConstantTimeByteEq
// compares ConstantTimeByteEq against it.
func byteEq(a, b uint8) int {
	if a != b {
		return 0
	}
	return 1
}
|  |  | 
|  | func TestConstantTimeByteEq(t *testing.T) { | 
|  | for i, test := range testConstandTimeByteEqData { | 
|  | if r := ConstantTimeByteEq(test.a, test.b); r != test.out { | 
|  | t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out) | 
|  | } | 
|  | } | 
|  | err := quick.CheckEqual(ConstantTimeByteEq, byteEq, nil) | 
|  | if err != nil { | 
|  | t.Error(err) | 
|  | } | 
|  | } | 
|  |  | 
// eq is the plain, branching reference for int32 equality: it returns 1
// when a and b are the same value and 0 otherwise. TestConstantTimeEq
// compares ConstantTimeEq against it.
func eq(a, b int32) int {
	if a != b {
		return 0
	}
	return 1
}
|  |  | 
|  | func TestConstantTimeEq(t *testing.T) { | 
|  | err := quick.CheckEqual(ConstantTimeEq, eq, nil) | 
|  | if err != nil { | 
|  | t.Error(err) | 
|  | } | 
|  | } | 
|  |  | 
// makeCopy is the straightforward reference for a conditional copy. Both
// slices are first cut down to the length of the shorter one; when v is
// exactly 1 the contents of y are copied into x. The (possibly shortened)
// x is returned in every case.
func makeCopy(v int, x, y []byte) []byte {
	// Length shared by both slices after truncation.
	n := len(x)
	if len(y) < n {
		n = len(y)
	}
	x, y = x[:n], y[:n]

	if v == 1 {
		copy(x, y)
	}
	return x
}
|  |  | 
|  | func constantTimeCopyWrapper(v int, x, y []byte) []byte { | 
|  | if len(x) > len(y) { | 
|  | x = x[0:len(y)] | 
|  | } else { | 
|  | y = y[0:len(x)] | 
|  | } | 
|  | v &= 1 | 
|  | ConstantTimeCopy(v, x, y) | 
|  | return x | 
|  | } | 
|  |  | 
|  | func TestConstantTimeCopy(t *testing.T) { | 
|  | err := quick.CheckEqual(constantTimeCopyWrapper, makeCopy, nil) | 
|  | if err != nil { | 
|  | t.Error(err) | 
|  | } | 
|  | } | 
|  |  | 
// lessOrEqTests lists the fixed cases checked by TestConstantTimeLessOrEq.
// Each entry holds the two operands and the expected result: 1 when
// x <= y and 0 otherwise.
var lessOrEqTests = []struct {
	x      int
	y      int
	result int
}{
	{x: 0, y: 0, result: 1},
	{x: 1, y: 0, result: 0},
	{x: 0, y: 1, result: 1},
	{x: 10, y: 20, result: 1},
	{x: 20, y: 10, result: 0},
	{x: 10, y: 10, result: 1},
}
|  |  | 
|  | func TestConstantTimeLessOrEq(t *testing.T) { | 
|  | for i, test := range lessOrEqTests { | 
|  | result := ConstantTimeLessOrEq(test.x, test.y) | 
|  | if result != test.result { | 
|  | t.Errorf("#%d: %d <= %d gave %d, expected %d", i, test.x, test.y, result, test.result) | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
// benchmarkGlobal receives the final value computed by each benchmark in
// this file once its loop has finished.
// NOTE(review): presumably this keeps the benchmarked calls from being
// treated as dead code by the compiler — confirm.
var benchmarkGlobal uint8
|  |  | 
|  | func BenchmarkConstantTimeByteEq(b *testing.B) { | 
|  | var x, y uint8 | 
|  |  | 
|  | for i := 0; i < b.N; i++ { | 
|  | x, y = uint8(ConstantTimeByteEq(x, y)), x | 
|  | } | 
|  |  | 
|  | benchmarkGlobal = x | 
|  | } | 
|  |  | 
|  | func BenchmarkConstantTimeEq(b *testing.B) { | 
|  | var x, y int | 
|  |  | 
|  | for i := 0; i < b.N; i++ { | 
|  | x, y = ConstantTimeEq(int32(x), int32(y)), x | 
|  | } | 
|  |  | 
|  | benchmarkGlobal = uint8(x) | 
|  | } | 
|  |  | 
|  | func BenchmarkConstantTimeLessOrEq(b *testing.B) { | 
|  | var x, y int | 
|  |  | 
|  | for i := 0; i < b.N; i++ { | 
|  | x, y = ConstantTimeLessOrEq(x, y), x | 
|  | } | 
|  |  | 
|  | benchmarkGlobal = uint8(x) | 
|  | } |