blob: 7c31e9d560ff0f5bae0edda5e9189d7cae9b0616 [file] [log] [blame]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5//go:build go1.21
6
7package quic
8
9import (
10 "fmt"
11 "reflect"
12 "testing"
13)
14
15func TestConnIDClientHandshake(t *testing.T) {
16 // On initialization, the client chooses local and remote IDs.
17 //
18 // The order in which we allocate the two isn't actually important,
19 // but test is a lot simpler if we assume.
20 var s connIDState
21 s.initClient(newConnIDSequence())
22 if got, want := string(s.srcConnID()), "local-1"; got != want {
23 t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
24 }
25 if got, want := string(s.dstConnID()), "local-2"; got != want {
26 t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
27 }
28
29 // The server's first Initial packet provides the client with a
30 // non-transient remote connection ID.
31 s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1"))
32 if got, want := string(s.dstConnID()), "remote-1"; got != want {
33 t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want)
34 }
35
36 wantLocal := []connID{{
37 cid: []byte("local-1"),
38 seq: 0,
39 }}
40 if !reflect.DeepEqual(s.local, wantLocal) {
41 t.Errorf("local ids: %v, want %v", s.local, wantLocal)
42 }
43 wantRemote := []connID{{
44 cid: []byte("remote-1"),
45 seq: 0,
46 }}
47 if !reflect.DeepEqual(s.remote, wantRemote) {
48 t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
49 }
50}
51
52func TestConnIDServerHandshake(t *testing.T) {
53 // On initialization, the server is provided with the client-chosen
54 // transient connection ID, and allocates an ID of its own.
55 // The Initial packet sets the remote connection ID.
56 var s connIDState
57 s.initServer(newConnIDSequence(), []byte("transient"))
58 s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1"))
59 if got, want := string(s.srcConnID()), "local-1"; got != want {
60 t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
61 }
62 if got, want := string(s.dstConnID()), "remote-1"; got != want {
63 t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
64 }
65
66 wantLocal := []connID{{
67 cid: []byte("transient"),
68 seq: -1,
69 }, {
70 cid: []byte("local-1"),
71 seq: 0,
72 }}
73 if !reflect.DeepEqual(s.local, wantLocal) {
74 t.Errorf("local ids: %v, want %v", s.local, wantLocal)
75 }
76 wantRemote := []connID{{
77 cid: []byte("remote-1"),
78 seq: 0,
79 }}
80 if !reflect.DeepEqual(s.remote, wantRemote) {
81 t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
82 }
83
84 // The client's first Handshake packet permits the server to discard the
85 // transient connection ID.
86 s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1"))
87 wantLocal = []connID{{
88 cid: []byte("local-1"),
89 seq: 0,
90 }}
91 if !reflect.DeepEqual(s.local, wantLocal) {
92 t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal)
93 }
94}
95
96func newConnIDSequence() newConnIDFunc {
97 var n uint64
98 return func() ([]byte, error) {
99 n++
100 return []byte(fmt.Sprintf("local-%v", n)), nil
101 }
102}
103
104func TestNewRandomConnID(t *testing.T) {
105 cid, err := newRandomConnID()
106 if len(cid) != connIDLen || err != nil {
107 t.Fatalf("newConnID() = %x, %v; want %v bytes", cid, connIDLen, err)
108 }
109}