blob: 4d8516ae6c88f7f3ac907b3727330b4b4abe6ff2 [file] [log] [blame] [edit]
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package quic
import (
"bytes"
"encoding/binary"
"testing"
)
// TestDecodePacketNumber checks decodePacketNumber against a table of
// (largest packet number, truncated value, encoded length in bytes) inputs.
func TestDecodePacketNumber(t *testing.T) {
	type decodeCase struct {
		largest   packetNumber
		truncated packetNumber
		want      packetNumber
		size      int
	}
	cases := []decodeCase{
		{largest: 0, truncated: 1, size: 4, want: 1},
		{largest: 0, truncated: 0, size: 1, want: 0},
		{largest: 0x00, truncated: 0x01, size: 1, want: 0x01},
		{largest: 0x00, truncated: 0xff, size: 1, want: 0xff},
		// The expected result lies in the next 256-value window above largest.
		{largest: 0xff, truncated: 0x01, size: 1, want: 0x101},
		// The expected result lies just below largest.
		{largest: 0x1000, truncated: 0xff, size: 1, want: 0xfff},
		// Worked example from RFC 9000, Appendix A.3.
		{largest: 0xa82f30ea, truncated: 0x9b32, size: 2, want: 0xa82f9b32},
	}
	for i := range cases {
		c := &cases[i]
		got := decodePacketNumber(c.largest, c.truncated, c.size)
		if got == c.want {
			continue
		}
		t.Errorf("decodePacketNumber(largest=0x%x, truncated=0x%x, size=%v) = 0x%x, want 0x%x", c.largest, c.truncated, c.size, got, c.want)
	}
}
// TestEncodePacketNumber checks each (largestAcked, pnum) pair three ways:
//   - packetNumberLength picks the expected encoded length;
//   - appendPacketNumber writes the low-order bytes of pnum, big-endian,
//     at that length;
//   - decodePacketNumber recovers pnum from those truncated bytes.
func TestEncodePacketNumber(t *testing.T) {
	for _, test := range []struct {
		largestAcked packetNumber
		pnum         packetNumber
		wantSize     int
	}{{
		largestAcked: -1,
		pnum:         0,
		wantSize:     1,
	}, {
		largestAcked: 1000,
		pnum:         1000 + 0x7f,
		wantSize:     1,
	}, {
		largestAcked: 1000,
		pnum:         1000 + 0x80, // 0x468
		wantSize:     2,
	}, {
		largestAcked: 0x12345678,
		pnum:         0x12345678 + 0x7fff, // 0x1234d677
		wantSize:     2,
	}, {
		largestAcked: 0x12345678,
		pnum:         0x12345678 + 0x8000,
		wantSize:     3,
	}, {
		largestAcked: 0,
		pnum:         0x7fffff,
		wantSize:     3,
	}, {
		largestAcked: 0,
		pnum:         0x800000,
		wantSize:     4,
	}, {
		// Modeled on the RFC 9000, Appendix A.2 examples.
		largestAcked: 0xabe8bc,
		pnum:         0xac5c02,
		wantSize:     2,
	}, {
		largestAcked: 0xabe8bc,
		pnum:         0xace8fe,
		wantSize:     3,
	}} {
		size := packetNumberLength(test.pnum, test.largestAcked)
		if got, want := size, test.wantSize; got != want {
			t.Errorf("packetNumberLength(num=%x, maxAck=%x) = %v, want %v", test.pnum, test.largestAcked, got, want)
		}
		if size < 1 || size > 4 {
			// The masking and slicing below only make sense for 1-4 bytes;
			// skip the remaining checks rather than panic on [4-size:].
			continue
		}
		// Only the low-order size bytes of pnum are carried on the wire.
		enc := test.pnum & (packetNumber(1)<<(8*size) - 1)
		wantBytes := binary.BigEndian.AppendUint32(nil, uint32(enc))[4-size:]
		gotBytes := appendPacketNumber(nil, test.pnum, test.largestAcked)
		if !bytes.Equal(gotBytes, wantBytes) {
			t.Errorf("appendPacketNumber(num=%x, maxAck=%x) = {%x}, want {%x}", test.pnum, test.largestAcked, gotBytes, wantBytes)
		}
		gotNum := decodePacketNumber(test.largestAcked, enc, size)
		if got, want := gotNum, test.pnum; got != want {
			t.Errorf("decodePacketNumber(largest=%x, truncated=%x, size=%v) = %x, want %x", test.largestAcked, enc, size, got, want)
		}
	}
}
// FuzzPacketNumber cross-checks the packet number decoder and encoder.
//
// For an arbitrary 1-4 byte big-endian truncated packet number and a
// non-negative largestAcked, it verifies the following:
//   - the decoded number's low-order bytes equal the input;
//   - when the decoded number is not below largestAcked, re-encoding it
//     agrees with the input on their common low-order bytes;
//   - the re-encoded bytes decode back to the same number;
//   - packetNumberLength matches the length appendPacketNumber chose.
func FuzzPacketNumber(f *testing.F) {
	// Seed corpus so a plain `go test` run (without -fuzz) still exercises
	// the checks below. The values mirror the TestDecodePacketNumber table.
	for _, seed := range []struct {
		in           []byte
		largestAcked int64
	}{
		{[]byte{0x01}, 0},
		{[]byte{0xff}, 0},
		{[]byte{0x01}, 0xff},
		{[]byte{0xff}, 0x1000},
		{[]byte{0x9b, 0x32}, 0xa82f30ea},
		{[]byte{0x00, 0x00, 0x00, 0x01}, 0},
	} {
		f.Add(seed.in, seed.largestAcked)
	}

	// truncatedNumber interprets in as a big-endian unsigned integer.
	truncatedNumber := func(in []byte) packetNumber {
		var truncated packetNumber
		for _, b := range in {
			truncated = (truncated << 8) | packetNumber(b)
		}
		return truncated
	}
	f.Fuzz(func(t *testing.T, in []byte, largestAckedInt64 int64) {
		largestAcked := packetNumber(largestAckedInt64)
		if len(in) < 1 || len(in) > 4 || largestAcked < 0 || largestAcked > maxPacketNumber {
			return
		}
		truncatedIn := truncatedNumber(in)
		decoded := decodePacketNumber(largestAcked, truncatedIn, len(in))

		// Check that the decoded packet number's least significant bits match the input.
		mask := packetNumber(1)<<(8*len(in)) - 1
		if truncatedIn != decoded&mask {
			t.Fatalf("decoding mismatch: input=%x largestAcked=%v decoded=0x%x", in, largestAcked, decoded)
		}

		// We don't support encoding packet numbers less than largestAcked (since packet numbers
		// never decrease), so skip the encoder tests if this would make us go backwards.
		if decoded < largestAcked {
			return
		}

		// We might encode this number using a different length than we received,
		// but the common (least significant) bytes should match.
		encoded := appendPacketNumber(nil, decoded, largestAcked)
		a, b := in, encoded
		if len(b) < len(a) {
			a, b = b, a
		}
		// Drop high-order bytes of the longer encoding so both cover the same span.
		b = b[len(b)-len(a):]
		if len(a) == 0 || !bytes.Equal(a, b) {
			t.Fatalf("encoding mismatch: input=%x largestAcked=%v decoded=%v reencoded=%x", in, largestAcked, decoded, encoded)
		}
		if g := decodePacketNumber(largestAcked, truncatedNumber(encoded), len(encoded)); g != decoded {
			t.Fatalf("packet encode/decode mismatch: pnum=%v largestAcked=%v encoded=%x got=%v", decoded, largestAcked, encoded, g)
		}
		if l := packetNumberLength(decoded, largestAcked); l != len(encoded) {
			t.Fatalf("packet number length mismatch: pnum=%v largestAcked=%v encoded=%x len=%v", decoded, largestAcked, encoded, l)
		}
	})
}