quic: send and receive UDP datagrams
Add the Listener type, which manages a UDP socket.
For golang/go#58547
Change-Id: Ia23a8b726ef46f8f84c9e052aa4dfc10eab034d6
Reviewed-on: https://go-review.googlesource.com/c/net/+/527758
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go
index d479cd4..c528958 100644
--- a/internal/quic/conn_id_test.go
+++ b/internal/quic/conn_id_test.go
@@ -11,100 +11,135 @@
"crypto/tls"
"fmt"
"net/netip"
- "reflect"
+ "strings"
"testing"
)
+// TestConnIDClientHandshake checks the client's local and remote connection IDs
+// at initialization and after it receives the server's first Initial packet.
func TestConnIDClientHandshake(t *testing.T) {
+ tc := newTestConn(t, clientSide)
// On initialization, the client chooses local and remote IDs.
//
// The order in which we allocate the two isn't actually important,
- // but test is a lot simpler if we assume.
+ // but the test is a lot simpler if we assume a fixed order.
- var s connIDState
- s.initClient(newConnIDSequence())
- if got, want := string(s.srcConnID()), "local-1"; got != want {
- t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+ if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) {
+ t.Errorf("after initialization: srcConnID = %x, want %x", got, want)
}
- dstConnID, _ := s.dstConnID()
- if got, want := string(dstConnID), "local-2"; got != want {
- t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+ dstConnID, _ := tc.conn.connIDState.dstConnID()
+ if got, want := dstConnID, testLocalConnID(-1); !bytes.Equal(got, want) {
+ t.Errorf("after initialization: dstConnID = %x, want %x", got, want)
}
// The server's first Initial packet provides the client with a
// non-transient remote connection ID.
- s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1"))
- dstConnID, _ = s.dstConnID()
- if got, want := string(dstConnID), "remote-1"; got != want {
- t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want)
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ dstConnID, _ = tc.conn.connIDState.dstConnID()
+ if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) {
+ t.Errorf("after receiving Initial: dstConnID = %x, want %x", got, want)
}
wantLocal := []connID{{
- cid: []byte("local-1"),
+ cid: testLocalConnID(0),
seq: 0,
}}
- if !reflect.DeepEqual(s.local, wantLocal) {
- t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+ if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+ t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
wantRemote := []connID{{
- cid: []byte("remote-1"),
+ cid: testPeerConnID(0),
seq: 0,
}}
- if !reflect.DeepEqual(s.remote, wantRemote) {
- t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+ if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
}
}
+// TestConnIDServerHandshake checks the server's local and remote connection IDs
+// as it receives the client's Initial CRYPTO data in two pieces and then the
+// client's first Handshake packet, after which the transient ID (seq -1) is dropped.
func TestConnIDServerHandshake(t *testing.T) {
+ tc := newTestConn(t, serverSide)
// On initialization, the server is provided with the client-chosen
// transient connection ID, and allocates an ID of its own.
// The Initial packet sets the remote connection ID.
- var s connIDState
- s.initServer(newConnIDSequence(), []byte("transient"))
- s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1"))
- if got, want := string(s.srcConnID()), "local-1"; got != want {
+ // Deliver only the first byte of the Initial CRYPTO data here; the rest,
+ // which carries the transport parameters, is written further below.
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][:1],
+ })
+ if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) {
- t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
+ t.Errorf("after receiving Initial: srcConnID = %x, want %x", got, want)
}
- dstConnID, _ := s.dstConnID()
- if got, want := string(dstConnID), "remote-1"; got != want {
+ dstConnID, _ := tc.conn.connIDState.dstConnID()
+ if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) {
- t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
+ t.Errorf("after receiving Initial: dstConnID = %x, want %x", got, want)
}
+ // The Initial flight of CRYPTO data includes transport parameters,
+ // which cause us to allocate another local connection ID.
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ off: 1,
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][1:],
+ })
wantLocal := []connID{{
- cid: []byte("transient"),
+ cid: testPeerConnID(-1),
seq: -1,
}, {
- cid: []byte("local-1"),
+ cid: testLocalConnID(0),
seq: 0,
+ }, {
+ cid: testLocalConnID(1),
+ seq: 1,
}}
- if !reflect.DeepEqual(s.local, wantLocal) {
- t.Errorf("local ids: %v, want %v", s.local, wantLocal)
+ if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+ t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
wantRemote := []connID{{
- cid: []byte("remote-1"),
+ cid: testPeerConnID(0),
seq: 0,
}}
- if !reflect.DeepEqual(s.remote, wantRemote) {
- t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
+ if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) {
+ t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote))
}
// The client's first Handshake packet permits the server to discard the
// transient connection ID.
- s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1"))
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
wantLocal = []connID{{
- cid: []byte("local-1"),
+ cid: testLocalConnID(0),
seq: 0,
+ }, {
+ cid: testLocalConnID(1),
+ seq: 1,
}}
- if !reflect.DeepEqual(s.local, wantLocal) {
- t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal)
+ if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) {
+ t.Errorf("after handshake local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
}
-func newConnIDSequence() newConnIDFunc {
- var n uint64
- return func(_ int64) ([]byte, error) {
- n++
- return []byte(fmt.Sprintf("local-%v", n)), nil
+// connIDListEqual reports whether a and b hold the same connection IDs in the
+// same order. Only the seq and cid fields of each connID are compared.
+func connIDListEqual(a, b []connID) bool {
+ if len(a) != len(b) {
+ return false
}
+ for i := range a {
+ if a[i].seq != b[i].seq {
+ return false
+ }
+ if !bytes.Equal(a[i].cid, b[i].cid) {
+ return false
+ }
+ }
+ return true
+}
+
+// fmtConnIDList formats s for test failure messages, rendering each entry as
+// its sequence number followed by its connection ID in hex.
+func fmtConnIDList(s []connID) string {
+ var strs []string
+ for _, cid := range s {
+ strs = append(strs, fmt.Sprintf("[seq:%v cid:{%x}]", cid.seq, cid.cid))
+ }
+ return "{" + strings.Join(strs, " ") + "}"
}
func TestNewRandomConnID(t *testing.T) {