blob: 4246ddd35b00b4a7c10f9cb25938fdc693cb4e8a [file] [log] [blame]
Adam Langley124e52d2012-03-12 10:59:04 -04001// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package curve25519
6
7import (
Michael McLoughlin193df9c2019-02-05 11:42:15 -08008 "bytes"
9 "crypto/rand"
Adam Langley124e52d2012-03-12 10:59:04 -040010 "fmt"
11 "testing"
12)
13
const (
	// expectedHex is the hex-encoded value that TestX25519Basepoint requires
	// as the result of its chain of X25519 calls.
	expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a"
)
15
Filippo Valsordaf4817d92019-11-07 16:07:46 -050016func TestX25519Basepoint(t *testing.T) {
17 x := make([]byte, 32)
18 x[0] = 1
Adam Langley124e52d2012-03-12 10:59:04 -040019
20 for i := 0; i < 200; i++ {
Filippo Valsordaf4817d92019-11-07 16:07:46 -050021 var err error
22 x, err = X25519(x, Basepoint)
23 if err != nil {
24 t.Fatal(err)
25 }
Adam Langley124e52d2012-03-12 10:59:04 -040026 }
27
Filippo Valsordaf4817d92019-11-07 16:07:46 -050028 result := fmt.Sprintf("%x", x)
Adam Langley124e52d2012-03-12 10:59:04 -040029 if result != expectedHex {
30 t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
31 }
32}
Andreas Auernhammercd115412017-04-06 15:01:18 +020033
Filippo Valsordaf4817d92019-11-07 16:07:46 -050034func TestLowOrderPoints(t *testing.T) {
35 scalar := make([]byte, ScalarSize)
36 if _, err := rand.Read(scalar); err != nil {
37 t.Fatal(err)
38 }
39 for i, p := range lowOrderPoints {
40 out, err := X25519(scalar, p)
41 if err == nil {
42 t.Errorf("%d: expected error, got nil", i)
43 }
44 if out != nil {
45 t.Errorf("%d: expected nil output, got %x", i, out)
46 }
47 }
48}
49
Michael McLoughlina1f597e2019-02-05 13:23:42 -080050func TestTestVectors(t *testing.T) {
Filippo Valsorda3497b512021-05-05 17:15:45 -040051 t.Run("Legacy", func(t *testing.T) { testTestVectors(t, ScalarMult) })
Filippo Valsordaf4817d92019-11-07 16:07:46 -050052 t.Run("X25519", func(t *testing.T) {
53 testTestVectors(t, func(dst, scalar, point *[32]byte) {
54 out, err := X25519(scalar[:], point[:])
55 if err != nil {
56 t.Fatal(err)
57 }
58 copy(dst[:], out)
59 })
60 })
61}
62
63func testTestVectors(t *testing.T, scalarMult func(dst, scalar, point *[32]byte)) {
Michael McLoughlina1f597e2019-02-05 13:23:42 -080064 for _, tv := range testVectors {
65 var got [32]byte
Filippo Valsordaf4817d92019-11-07 16:07:46 -050066 scalarMult(&got, &tv.In, &tv.Base)
Michael McLoughlina1f597e2019-02-05 13:23:42 -080067 if !bytes.Equal(got[:], tv.Expect[:]) {
68 t.Logf(" in = %x", tv.In)
69 t.Logf(" base = %x", tv.Base)
70 t.Logf(" got = %x", got)
71 t.Logf("expect = %x", tv.Expect)
72 t.Fail()
73 }
74 }
75}
76
Michael McLoughlin193df9c2019-02-05 11:42:15 -080077// TestHighBitIgnored tests the following requirement in RFC 7748:
78//
79// When receiving such an array, implementations of X25519 (but not X448) MUST
80// mask the most significant bit in the final byte.
81//
82// Regression test for issue #30095.
83func TestHighBitIgnored(t *testing.T) {
84 var s, u [32]byte
85 rand.Read(s[:])
86 rand.Read(u[:])
87
88 var hi0, hi1 [32]byte
89
90 u[31] &= 0x7f
91 ScalarMult(&hi0, &s, &u)
92
93 u[31] |= 0x80
94 ScalarMult(&hi1, &s, &u)
95
96 if !bytes.Equal(hi0[:], hi1[:]) {
97 t.Errorf("high bit of group point should not affect result")
98 }
99}
100
var (
	// benchmarkSink receives one byte of each benchmark iteration's output
	// (XORed in), presumably so the measured calls have an observable effect
	// and cannot be optimized away.
	benchmarkSink byte
)
Andreas Auernhammercd115412017-04-06 15:01:18 +0200102
Filippo Valsorda3497b512021-05-05 17:15:45 -0400103func BenchmarkX25519Basepoint(b *testing.B) {
104 scalar := make([]byte, ScalarSize)
105 if _, err := rand.Read(scalar); err != nil {
106 b.Fatal(err)
107 }
108
109 b.ResetTimer()
Andreas Auernhammercd115412017-04-06 15:01:18 +0200110 for i := 0; i < b.N; i++ {
Filippo Valsorda3497b512021-05-05 17:15:45 -0400111 out, err := X25519(scalar, Basepoint)
112 if err != nil {
113 b.Fatal(err)
114 }
115 benchmarkSink ^= out[0]
116 }
117}
118
119func BenchmarkX25519(b *testing.B) {
120 scalar := make([]byte, ScalarSize)
121 if _, err := rand.Read(scalar); err != nil {
122 b.Fatal(err)
123 }
124 point, err := X25519(scalar, Basepoint)
125 if err != nil {
126 b.Fatal(err)
127 }
128 if _, err := rand.Read(scalar); err != nil {
129 b.Fatal(err)
130 }
131
132 b.ResetTimer()
133 for i := 0; i < b.N; i++ {
134 out, err := X25519(scalar, point)
135 if err != nil {
136 b.Fatal(err)
137 }
138 benchmarkSink ^= out[0]
Andreas Auernhammercd115412017-04-06 15:01:18 +0200139 }
140}