| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| //go:build go1.24 && goexperiment.synctest |
| |
| package http3 |
| |
| import ( |
| "bytes" |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "maps" |
| "net/http" |
| "reflect" |
| "slices" |
| "testing" |
| "testing/synctest" |
| |
| "golang.org/x/net/internal/quic/quicwire" |
| "golang.org/x/net/quic" |
| ) |
| |
| func TestTransportServerCreatesBidirectionalStream(t *testing.T) { |
| // "Clients MUST treat receipt of a server-initiated bidirectional |
| // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]" |
| // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3 |
| runSynctest(t, func(t testing.TB) { |
| tc := newTestClientConn(t) |
| tc.greet() |
| st := tc.newStream(streamTypeRequest) |
| st.Flush() |
| tc.wantClosed("after server creates bidi stream", errH3StreamCreationError) |
| }) |
| } |
| |
| // A testQUICConn wraps a *quic.Conn and provides methods for inspecting it. |
| type testQUICConn struct { |
| t testing.TB |
| qconn *quic.Conn |
| streams map[streamType][]*testQUICStream |
| } |
| |
| func newTestQUICConn(t testing.TB, qconn *quic.Conn) *testQUICConn { |
| tq := &testQUICConn{ |
| t: t, |
| qconn: qconn, |
| streams: make(map[streamType][]*testQUICStream), |
| } |
| |
| go tq.acceptStreams(t.Context()) |
| |
| t.Cleanup(func() { |
| tq.qconn.Close() |
| }) |
| return tq |
| } |
| |
| func (tq *testQUICConn) acceptStreams(ctx context.Context) { |
| for { |
| qst, err := tq.qconn.AcceptStream(ctx) |
| if err != nil { |
| return |
| } |
| st := newStream(qst) |
| stype := streamTypeRequest |
| if qst.IsReadOnly() { |
| v, err := st.readVarint() |
| if err != nil { |
| tq.t.Errorf("error reading stream type from unidirectional stream: %v", err) |
| continue |
| } |
| stype = streamType(v) |
| } |
| tq.streams[stype] = append(tq.streams[stype], newTestQUICStream(tq.t, st)) |
| } |
| } |
| |
| func (tq *testQUICConn) newStream(stype streamType) *testQUICStream { |
| tq.t.Helper() |
| var qs *quic.Stream |
| var err error |
| if stype == streamTypeRequest { |
| qs, err = tq.qconn.NewStream(canceledCtx) |
| } else { |
| qs, err = tq.qconn.NewSendOnlyStream(canceledCtx) |
| } |
| if err != nil { |
| tq.t.Fatal(err) |
| } |
| st := newStream(qs) |
| if stype != streamTypeRequest { |
| st.writeVarint(int64(stype)) |
| if err := st.Flush(); err != nil { |
| tq.t.Fatal(err) |
| } |
| } |
| return newTestQUICStream(tq.t, st) |
| } |
| |
| // wantNotClosed asserts that the peer has not closed the connectioln. |
| func (tq *testQUICConn) wantNotClosed(reason string) { |
| t := tq.t |
| t.Helper() |
| synctest.Wait() |
| err := tq.qconn.Wait(canceledCtx) |
| if !errors.Is(err, context.Canceled) { |
| t.Fatalf("%v: want QUIC connection to be alive; closed with error: %v", reason, err) |
| } |
| } |
| |
| // wantClosed asserts that the peer has closed the connection |
| // with the provided error code. |
| func (tq *testQUICConn) wantClosed(reason string, want error) { |
| t := tq.t |
| t.Helper() |
| synctest.Wait() |
| |
| if e, ok := want.(http3Error); ok { |
| want = &quic.ApplicationError{Code: uint64(e)} |
| } |
| got := tq.qconn.Wait(canceledCtx) |
| if errors.Is(got, context.Canceled) { |
| t.Fatalf("%v: want QUIC connection closed, but it is not", reason) |
| } |
| if !errors.Is(got, want) { |
| t.Fatalf("%v: connection closed with error: %v; want %v", reason, got, want) |
| } |
| } |
| |
| // wantStream asserts that a stream of a given type has been created, |
| // and returns that stream. |
| func (tq *testQUICConn) wantStream(stype streamType) *testQUICStream { |
| tq.t.Helper() |
| synctest.Wait() |
| if len(tq.streams[stype]) == 0 { |
| tq.t.Fatalf("expected a %v stream to be created, but none were", stype) |
| } |
| ts := tq.streams[stype][0] |
| tq.streams[stype] = tq.streams[stype][1:] |
| return ts |
| } |
| |
| // testQUICStream wraps a QUIC stream and provides methods for inspecting it. |
| type testQUICStream struct { |
| t testing.TB |
| *stream |
| } |
| |
| func newTestQUICStream(t testing.TB, st *stream) *testQUICStream { |
| st.stream.SetReadContext(canceledCtx) |
| st.stream.SetWriteContext(canceledCtx) |
| return &testQUICStream{ |
| t: t, |
| stream: st, |
| } |
| } |
| |
| // wantFrameHeader calls readFrameHeader and asserts that the frame is of a given type. |
| func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) { |
| ts.t.Helper() |
| synctest.Wait() |
| gotType, err := ts.readFrameHeader() |
| if err != nil { |
| ts.t.Fatalf("%v: failed to read frame header: %v", reason, err) |
| } |
| if gotType != wantType { |
| ts.t.Fatalf("%v: got frame type %v, want %v", reason, gotType, wantType) |
| } |
| } |
| |
| // wantHeaders reads a HEADERS frame. |
| // If want is nil, the contents of the frame are ignored. |
| func (ts *testQUICStream) wantHeaders(want http.Header) { |
| ts.t.Helper() |
| ftype, err := ts.readFrameHeader() |
| if err != nil { |
| ts.t.Fatalf("want HEADERS frame, got error: %v", err) |
| } |
| if ftype != frameTypeHeaders { |
| ts.t.Fatalf("want HEADERS frame, got: %v", ftype) |
| } |
| |
| if want == nil { |
| if err := ts.discardFrame(); err != nil { |
| ts.t.Fatalf("discardFrame: %v", err) |
| } |
| return |
| } |
| |
| got := make(http.Header) |
| var dec qpackDecoder |
| err = dec.decode(ts.stream, func(_ indexType, name, value string) error { |
| got.Add(name, value) |
| return nil |
| }) |
| if diff := diffHeaders(got, want); diff != "" { |
| ts.t.Fatalf("unexpected response headers:\n%v", diff) |
| } |
| if err := ts.endFrame(); err != nil { |
| ts.t.Fatalf("endFrame: %v", err) |
| } |
| } |
| |
| func (ts *testQUICStream) encodeHeaders(h http.Header) []byte { |
| ts.t.Helper() |
| var enc qpackEncoder |
| return enc.encode(func(yield func(itype indexType, name, value string)) { |
| names := slices.Collect(maps.Keys(h)) |
| slices.Sort(names) |
| for _, k := range names { |
| for _, v := range h[k] { |
| yield(mayIndex, k, v) |
| } |
| } |
| }) |
| } |
| |
| func (ts *testQUICStream) writeHeaders(h http.Header) { |
| ts.t.Helper() |
| headers := ts.encodeHeaders(h) |
| ts.writeVarint(int64(frameTypeHeaders)) |
| ts.writeVarint(int64(len(headers))) |
| ts.Write(headers) |
| if err := ts.Flush(); err != nil { |
| ts.t.Fatalf("flushing HEADERS frame: %v", err) |
| } |
| } |
| |
| func (ts *testQUICStream) wantData(want []byte) { |
| ts.t.Helper() |
| synctest.Wait() |
| ftype, err := ts.readFrameHeader() |
| if err != nil { |
| ts.t.Fatalf("want DATA frame, got error: %v", err) |
| } |
| if ftype != frameTypeData { |
| ts.t.Fatalf("want DATA frame, got: %v", ftype) |
| } |
| got, err := ts.readFrameData() |
| if err != nil { |
| ts.t.Fatalf("error reading DATA frame: %v", err) |
| } |
| if !bytes.Equal(got, want) { |
| ts.t.Fatalf("got data: {%x}, want {%x}", got, want) |
| } |
| if err := ts.endFrame(); err != nil { |
| ts.t.Fatalf("endFrame: %v", err) |
| } |
| } |
| |
| func (ts *testQUICStream) wantClosed(reason string) { |
| ts.t.Helper() |
| synctest.Wait() |
| ftype, err := ts.readFrameHeader() |
| if err != io.EOF { |
| ts.t.Fatalf("%v: want io.EOF, got %v %v", reason, ftype, err) |
| } |
| } |
| |
| func (ts *testQUICStream) wantError(want quic.StreamErrorCode) { |
| ts.t.Helper() |
| synctest.Wait() |
| _, err := ts.stream.stream.ReadByte() |
| if err == nil { |
| ts.t.Fatalf("successfully read from stream; want stream error code %v", want) |
| } |
| var got quic.StreamErrorCode |
| if !errors.As(err, &got) { |
| ts.t.Fatalf("stream error = %v; want %v", err, want) |
| } |
| if got != want { |
| ts.t.Fatalf("stream error code = %v; want %v", got, want) |
| } |
| } |
| |
| func (ts *testQUICStream) writePushPromise(pushID int64, h http.Header) { |
| ts.t.Helper() |
| headers := ts.encodeHeaders(h) |
| ts.writeVarint(int64(frameTypePushPromise)) |
| ts.writeVarint(int64(quicwire.SizeVarint(uint64(pushID)) + len(headers))) |
| ts.writeVarint(pushID) |
| ts.Write(headers) |
| if err := ts.Flush(); err != nil { |
| ts.t.Fatalf("flushing PUSH_PROMISE frame: %v", err) |
| } |
| } |
| |
// diffHeaders compares two header sets.
// It returns the empty string if they are equal, and otherwise a
// description showing both. A nil header and a non-nil header with no
// entries are considered equal.
func diffHeaders(got, want http.Header) string {
	switch {
	case len(got) == 0 && len(want) == 0:
		// nil and 0-length non-nil are equal.
		return ""
	case reflect.DeepEqual(got, want):
		// We could do a more sophisticated diff here.
		// DeepEqual is good enough for now.
		return ""
	}
	return fmt.Sprintf("got: %v\nwant: %v", got, want)
}
| |
| func (ts *testQUICStream) Flush() error { |
| err := ts.stream.Flush() |
| ts.t.Helper() |
| if err != nil { |
| ts.t.Errorf("unexpected error flushing stream: %v", err) |
| } |
| return err |
| } |
| |
| // A testClientConn is a ClientConn on a test network. |
| type testClientConn struct { |
| tr *Transport |
| cc *ClientConn |
| |
| // *testQUICConn is the server half of the connection. |
| *testQUICConn |
| control *testQUICStream |
| } |
| |
| func newTestClientConn(t testing.TB) *testClientConn { |
| e1, e2 := newQUICEndpointPair(t) |
| tr := &Transport{ |
| Endpoint: e1, |
| Config: &quic.Config{ |
| TLSConfig: testTLSConfig, |
| }, |
| } |
| |
| cc, err := tr.Dial(t.Context(), e2.LocalAddr().String()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| t.Cleanup(func() { |
| cc.Close() |
| }) |
| srvConn, err := e2.Accept(t.Context()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| tc := &testClientConn{ |
| tr: tr, |
| cc: cc, |
| testQUICConn: newTestQUICConn(t, srvConn), |
| } |
| synctest.Wait() |
| return tc |
| } |
| |
| // greet performs initial connection handshaking with the client. |
| func (tc *testClientConn) greet() { |
| // Client creates a control stream. |
| clientControlStream := tc.wantStream(streamTypeControl) |
| clientControlStream.wantFrameHeader( |
| "client sends SETTINGS frame on control stream", |
| frameTypeSettings) |
| clientControlStream.discardFrame() |
| |
| // Server creates a control stream. |
| tc.control = tc.newStream(streamTypeControl) |
| tc.control.writeVarint(int64(frameTypeSettings)) |
| tc.control.writeVarint(0) // size |
| tc.control.Flush() |
| |
| synctest.Wait() |
| } |
| |
// A testRoundTrip tracks a RoundTrip call started by
// testClientConn.roundTrip.
type testRoundTrip struct {
	// t is the test that failures are reported to.
	t testing.TB

	// resp and respErr hold the values returned by RoundTrip.
	// They are set by the goroutine that roundTrip starts.
	resp    *http.Response
	respErr error
}
| |
| func (rt *testRoundTrip) done() bool { |
| synctest.Wait() |
| return rt.resp != nil || rt.respErr != nil |
| } |
| |
| func (rt *testRoundTrip) result() (*http.Response, error) { |
| rt.t.Helper() |
| if !rt.done() { |
| rt.t.Fatal("RoundTrip is not done; want it to be") |
| } |
| return rt.resp, rt.respErr |
| } |
| |
| func (rt *testRoundTrip) response() *http.Response { |
| rt.t.Helper() |
| if !rt.done() { |
| rt.t.Fatal("RoundTrip is not done; want it to be") |
| } |
| if rt.respErr != nil { |
| rt.t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr) |
| } |
| return rt.resp |
| } |
| |
| // err returns the (possibly nil) error result of RoundTrip. |
| func (rt *testRoundTrip) err() error { |
| rt.t.Helper() |
| _, err := rt.result() |
| return err |
| } |
| |
| func (rt *testRoundTrip) wantError(reason string) { |
| rt.t.Helper() |
| synctest.Wait() |
| if !rt.done() { |
| rt.t.Fatalf("%v: RoundTrip is not done; want it to have returned an error", reason) |
| } |
| if rt.respErr == nil { |
| rt.t.Fatalf("%v: RoundTrip succeeded; want it to have returned an error", reason) |
| } |
| } |
| |
| // wantStatus indicates the expected response StatusCode. |
| func (rt *testRoundTrip) wantStatus(want int) { |
| rt.t.Helper() |
| if got := rt.response().StatusCode; got != want { |
| rt.t.Fatalf("got response status %v, want %v", got, want) |
| } |
| } |
| |
| func (rt *testRoundTrip) wantHeaders(want http.Header) { |
| rt.t.Helper() |
| if diff := diffHeaders(rt.response().Header, want); diff != "" { |
| rt.t.Fatalf("unexpected response headers:\n%v", diff) |
| } |
| } |
| |
| func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { |
| rt := &testRoundTrip{t: tc.t} |
| go func() { |
| rt.resp, rt.respErr = tc.cc.RoundTrip(req) |
| }() |
| return rt |
| } |
| |
// canceledCtx is a Context that has already been canceled.
// Used for performing non-blocking QUIC operations.
var canceledCtx = func() context.Context {
	c, stop := context.WithCancel(context.Background())
	// The deferred cancel runs before this function returns,
	// so the stored context is already canceled.
	defer stop()
	return c
}()