diff --git a/go.mod b/go.mod index f43f66a..c435e76 100644 --- a/go.mod +++ b/go.mod
@@ -3,8 +3,8 @@ go 1.25.0 require ( - golang.org/x/crypto v0.48.0 - golang.org/x/sys v0.41.0 - golang.org/x/term v0.40.0 - golang.org/x/text v0.34.0 + golang.org/x/crypto v0.50.0 + golang.org/x/sys v0.43.0 + golang.org/x/term v0.42.0 + golang.org/x/text v0.36.0 )
diff --git a/go.sum b/go.sum index 6fae034..27d5cd0 100644 --- a/go.sum +++ b/go.sum
@@ -1,8 +1,8 @@ -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go index d89c257..5ab499b 100644 --- a/http/httpproxy/proxy.go +++ b/http/httpproxy/proxy.go
@@ -3,8 +3,8 @@ // license that can be found in the LICENSE file. // Package httpproxy provides support for HTTP proxy determination -// based on environment variables, as provided by net/http's -// ProxyFromEnvironment function. +// based on environment variables, as provided by +// the [net/http.ProxyFromEnvironment] function. // // The API is not subject to the Go 1 compatibility promise and may change at // any time. @@ -56,7 +56,7 @@ // presence of a REQUEST_METHOD environment variable). // When this is set, ProxyForURL will return an error // when HTTPProxy applies, because a client could be - // setting HTTP_PROXY maliciously. See https://golang.org/s/cgihttpproxy. + // setting HTTP_PROXY maliciously. See https://go.dev/s/cgihttpproxy. CGI bool } @@ -113,7 +113,7 @@ // environment, or a proxy should not be used for the given request, as // defined by NO_PROXY. // -// As a special case, if req.URL.Host is "localhost" or a loopback address +// As a special case, if reqURL.Host is "localhost" or a loopback address // (with or without a port number), then a nil URL and nil error will be returned. func (cfg *Config) ProxyFunc() func(reqURL *url.URL) (*url.URL, error) { // Preprocess the Config settings for more efficient evaluation.
diff --git a/http2/clientconn_go125_test.go b/http2/clientconn_go125_test.go new file mode 100644 index 0000000..a4027b7 --- /dev/null +++ b/http2/clientconn_go125_test.go
@@ -0,0 +1,37 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.26 + +package http2_test + +import ( + "net/http" + "testing" +) + +type httpClientConn struct{} // placeholder: http.ClientConn was added in Go 1.26, and this variant never sets testClientConn.cc1 + +func (httpClientConn) RoundTrip(*http.Request) (*http.Response, error) { + panic("should never be called") +} + +func newTestClientConn(t testing.TB, opts ...any) *testClientConn { + t.Helper() + + tt := newTestTransport(t, opts...) + + if tt.mode == roundTripNetHTTP { // this mode needs net/http.Transport.NewClientConn (Go 1.26+) + t.Skip("roundTripNetHTTP not supported go <1.26: no NewClientConn") + } + + nc := tt.li.newConn() // client half; the server half is queued for tt.li.Accept + const singleUse = false + _, err := tt.tr.TestNewClientConn(nc, singleUse, nil) + if err != nil { + t.Fatalf("newClientConn: %v", err) + } + + return tt.getConn() +}
diff --git a/http2/clientconn_go126_test.go b/http2/clientconn_go126_test.go new file mode 100644 index 0000000..23db771 --- /dev/null +++ b/http2/clientconn_go126_test.go
@@ -0,0 +1,43 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.26 + +package http2_test + +import ( + "net/http" + "testing" +) + +type httpClientConn = http.ClientConn + +func newTestClientConn(t testing.TB, opts ...any) *testClientConn { + t.Helper() + + tt := newTestTransport(t, opts...) + + switch tt.mode { + case roundTripNetHTTP: + cc, err := tt.tr1.NewClientConn(t.Context(), "http", "localhost:80") // address is not really dialed: newTestTransport's DialContext ignores it and returns tt.li.newConn() + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + + tc := tt.getConn() + tc.cc1 = cc // testClientConn.roundTrip now goes through net/http.ClientConn + return tc + case roundTripXNetHTTP2: + nc := tt.li.newConn() + const singleUse = false + _, err := tt.tr.TestNewClientConn(nc, singleUse, nil) + if err != nil { + t.Fatalf("newClientConn: %v", err) + } + + return tt.getConn() + default: + panic("unknown test mode") + } +}
diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 10b1fc1..90aa794 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go
@@ -11,15 +11,20 @@ "bytes" "context" "crypto/tls" + "errors" "fmt" "io" + "net" "net/http" "reflect" + "slices" + "sync" "sync/atomic" "testing" "testing/synctest" "time" + "golang.org/x/net/http2" . "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "golang.org/x/net/internal/gate" @@ -84,6 +89,48 @@ rt.wantBody(nil) } +func TestTestTransport(t *testing.T) { + synctestSubtest(t, "nethttp", func(t testing.TB) { + testTestTransport(t, roundTripNetHTTP) + }) + synctestSubtest(t, "xnethttp2", func(t testing.TB) { + testTestTransport(t, roundTripXNetHTTP2) + }) +} +func testTestTransport(t testing.TB, mode roundTripTestMode) { + tt := newTestTransport(t, mode) + tt.useTLS = mode == roundTripNetHTTP // net/http.Transport performs the client TLS handshake itself for https URLs + req := Must(http.NewRequest("GET", "https://dummy.tld/", nil)) + rt := tt.roundTrip(req) + tc := tt.getConn() + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + tc.wantHeaders(wantHeader{ + streamID: 1, + endStream: true, + header: http.Header{ + ":authority": []string{"dummy.tld"}, + ":method": []string{"GET"}, + ":path": []string{"/"}, + }, + }) + + tc.writeSettings() + tc.writeSettingsAck() + tc.wantFrameType(FrameSettings) // acknowledgement + tc.writeHeaders(HeadersFrameParam{ + StreamID: 1, + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + rt.wantStatus(200) +} + // A testClientConn allows testing ClientConn.RoundTrip against a fake server.
// // A test using testClientConn consists of: @@ -96,59 +143,49 @@ type testClientConn struct { t testing.TB - tr *Transport - fr *Framer - cc *ClientConn + tr *Transport + fr *Framer + cc *ClientConn + cc1 *httpClientConn // non-nil only when RoundTrip should go through net/http.ClientConn testConnFramer encbuf bytes.Buffer enc *hpack.Encoder - roundtrips []*testRoundTrip - - netconn *synctestNetConn + netconn *synctestNetConn + connReader *nonblockingReader } -func newTestClientConnFromClientConn(t testing.TB, tr *Transport, cc *ClientConn) *testClientConn { +func newTestClientConnFromNetConn(tt *testTransport, nc net.Conn) *testClientConn { tc := &testClientConn{ - t: t, - tr: tr, - cc: cc, + t: tt.t, + tr: tt.tr, } - // srv is the side controlled by the test. - var srv *synctestNetConn - if tconn := cc.TestNetConn(); tconn == nil { - // If cc.tconn is nil, we're being called with a new conn created by the - // Transport's client pool. This path skips dialing the server, and we - // create a test connection pair here. - var cli *synctestNetConn - cli, srv = synctestNetPipe() - cc.TestSetNetConn(cli) + var writer io.Writer + var reader io.Reader + if tt.useTLS { // the peer speaks TLS on nc, so terminate it on the test side + tlsConfig := testTLSServerConfig.Clone() + tlsConfig.NextProtos = []string{"h2"} + tlsConn := tls.Server(nc, tlsConfig) + reader = tlsConn + writer = tlsConn } else { - // If cc.tconn is non-nil, we're in a test which provides a conn to the - // Transport via a TLSNextProto hook. Extract the test connection pair. - if tc, ok := tconn.(*tls.Conn); ok { - // Unwrap any *tls.Conn to the underlying net.Conn, - // to avoid dealing with encryption in tests.
- tconn = tc.NetConn() - cc.TestSetNetConn(tconn) - } - srv = tconn.(*synctestNetConn).peer + reader = nc + writer = nc } + tc.connReader = newNonblockingReader(reader) - srv.SetReadDeadline(time.Now()) - srv.autoWait = true - tc.netconn = srv + tc.netconn = nc.(*synctestNetConn) tc.enc = hpack.NewEncoder(&tc.encbuf) - tc.fr = NewFramer(srv, srv) + tc.fr = NewFramer(writer, tc.connReader) // reads see only data connReader has already buffered tc.testConnFramer = testConnFramer{ - t: t, + t: tt.t, fr: tc.fr, dec: hpack.NewDecoder(InitialHeaderTableSize, nil), } tc.fr.SetMaxReadFrameSize(10 << 20) - t.Cleanup(func() { + tt.t.Cleanup(func() { tc.closeWrite() }) @@ -158,8 +195,9 @@ func (tc *testClientConn) readClientPreface() { tc.t.Helper() // Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames. + synctest.Wait() // connReader.Read does not wait for new data, so let the client write first buf := make([]byte, len(ClientPreface)) - if _, err := io.ReadFull(tc.netconn, buf); err != nil { + if _, err := io.ReadFull(tc.connReader, buf); err != nil { tc.t.Fatalf("reading preface: %v", err) } if !bytes.Equal(buf, []byte(ClientPreface)) { @@ -167,26 +205,15 @@ } } -func newTestClientConn(t testing.TB, opts ...any) *testClientConn { - t.Helper() - - tt := newTestTransport(t, opts...) - const singleUse = false - _, err := tt.tr.TestNewClientConn(nil, singleUse, nil) - if err != nil { - t.Fatalf("newClientConn: %v", err) - } - - return tt.getConn() -} - // hasFrame reports whether a frame is available to be read. func (tc *testClientConn) hasFrame() bool { - return len(tc.netconn.Peek()) > 0 + synctest.Wait() + return tc.connReader.buf.Len() > 0 // NOTE(review): reads buf without connReader.mu; relies on the synctest.Wait above having quiesced the reader goroutine } // isClosed reports whether the peer has closed the connection. func (tc *testClientConn) isClosed() bool { + synctest.Wait() return tc.netconn.IsClosedByPeer() } @@ -297,23 +324,41 @@ // (Note that the RoundTrip won't complete until response headers are received, // the request times out, or some other terminal condition is reached.)
func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { - ctx, cancel := context.WithCancel(req.Context()) - req = req.WithContext(ctx) - rt := &testRoundTrip{ - t: tc.t, - donec: make(chan struct{}), - cancel: cancel, - } - tc.roundtrips = append(tc.roundtrips, rt) - go func() { - defer close(rt.donec) - rt.resp, rt.respErr = tc.cc.TestRoundTrip(req, func(streamID uint32) { + rt := &testRoundTrip{} + rt.do(tc.t, req, func(req *http.Request) (*http.Response, error) { + if tc.cc1 != nil { // net/http.ClientConn path: rt.id is never recorded, so rt.streamID() would panic + return tc.cc1.RoundTrip(req) + } + return tc.cc.TestRoundTrip(req, func(streamID uint32) { rt.id.Store(streamID) }) + }) + return rt +} + +func newTestRoundTrip(t testing.TB, req *http.Request, f func(*http.Request) (*http.Response, error)) *testRoundTrip { + rt := &testRoundTrip{} + rt.do(t, req, f) + return rt +} + +func (rt *testRoundTrip) do(t testing.TB, req *http.Request, f func(*http.Request) (*http.Response, error)) { + if rt.t != nil { // do starts f at most once per testRoundTrip + t.Fatal("testRoundTrip can only be used once") + } + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + rt.t = t + rt.donec = make(chan struct{}) + rt.cancel = cancel + go func() { + defer close(rt.donec) + rt.resp, rt.respErr = f(req) }() synctest.Wait() - tc.t.Cleanup(func() { + t.Cleanup(func() { + rt.cancel() // cancel the request context so a RoundTrip still in flight can return if !rt.done() { return } @@ -322,8 +367,6 @@ res.Body.Close() } }) - - return rt } func (tc *testClientConn) greet(settings ...Setting) { @@ -352,6 +395,7 @@ // inflowWindow returns the amount of inbound flow control available for a stream, // or for the connection if streamID is 0. func (tc *testClientConn) inflowWindow(streamID uint32) int32 { + synctest.Wait() w, err := tc.cc.TestInflowWindow(streamID) if err != nil { tc.t.Error(err) @@ -371,6 +415,7 @@ // streamID returns the HTTP/2 stream ID of the request.
func (rt *testRoundTrip) streamID() uint32 { + synctest.Wait() id := rt.id.Load() if id == 0 { panic("stream ID unknown") @@ -380,6 +425,7 @@ // done reports whether RoundTrip has returned. func (rt *testRoundTrip) done() bool { + synctest.Wait() select { case <-rt.donec: return true @@ -487,22 +533,97 @@ return fmt.Sprintf("got: %v\nwant: %v", got, want) } +// roundTripTestMode selects which RoundTrip API a test uses. +type roundTripTestMode int + +const ( + // roundTripNetHTTP uses net/http.Transport.RoundTrip or + // net/http.ClientConn.RoundTrip: + // + // t1 := &http.Transport{} + // t2, err := ConfigureTransports(t1) + // resp, err := t1.RoundTrip(req) + // + roundTripNetHTTP = roundTripTestMode(iota) + + // roundTripXNetHTTP2 uses x/net/http2.Transport.RoundTrip or + // x/net/http2.ClientConn.RoundTrip: + // + // t2 := &http2.Transport{} + // resp, err := t2.RoundTrip(req) + roundTripXNetHTTP2 +) + // A testTransport allows testing Transport.RoundTrip against fake servers. // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling // should use testClientConn instead.
type testTransport struct { - t testing.TB - tr *Transport + t testing.TB + tr *Transport + tr1 *http.Transport + li *synctestNetListener + mode roundTripTestMode - ccs []*testClientConn + ccMu sync.Mutex // guards ccqueue and ccpending + ccqueue []*testClientConn + ccs map[*synctestNetConn]*testClientConn + + ccpending []*testPendingClientConn + + useTLS bool // if set, the test side of each accepted conn terminates TLS +} + +type testPendingClientConn struct { + nc *synctestNetConn + cc *ClientConn + tc *testClientConn } func newTestTransport(t testing.TB, opts ...any) *testTransport { tt := &testTransport{ - t: t, + t: t, + li: newSynctestNetListener(), + ccs: make(map[*synctestNetConn]*testClientConn), + mode: roundTripXNetHTTP2, } - tr := &Transport{} + for _, o := range opts { + switch o := o.(type) { + case roundTripTestMode: + tt.mode = o + } + } + + var ( + tr *Transport + tr1 *http.Transport + ) + switch tt.mode { + case roundTripXNetHTTP2: + tr = &Transport{ + DialTLSContext: func(ctx context.Context, network, address string, tlsConf *tls.Config) (net.Conn, error) { + return tt.li.newConn(), nil + }, + AllowHTTP: true, + } + case roundTripNetHTTP: + tr1 = &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return tt.li.newConn(), nil + }, + Protocols: &http.Protocols{}, + TLSClientConfig: testTLSClientConfig, + } + tr1.Protocols.SetHTTP2(true) + tr1.Protocols.SetUnencryptedHTTP2(true) + t.Cleanup(tr1.CloseIdleConnections) + var err error + tr, err = ConfigureTransports(tr1) + if err != nil { + t.Fatal(err) + } + } + for _, o := range opts { switch o := o.(type) { case func(*http.Transport): @@ -511,54 +632,125 @@ o(tr) case *Transport: tr = o + case roundTripTestMode: + tt.mode = o + case nil: + default: + t.Fatalf("unsupported option %T", o) } } tt.tr = tr + tt.tr1 = tr.TestTransport() - tr.TestSetNewClientConnHook(func(cc *ClientConn) { - tc := newTestClientConnFromClientConn(t, tr, cc) - tt.ccs = append(tt.ccs, tc) + go tt.accept() + + tt.tr.TestSetNewClientConnHook(func(cc
*http2.ClientConn) { + nc, ok := cc.TestNetConn().(*synctestNetConn) + if !ok { + return + } + tt.addPending(nc.peer, cc, nil) }) t.Cleanup(func() { + tt.li.Close() synctest.Wait() - if len(tt.ccs) > 0 { - t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) + if len(tt.ccqueue) > 0 { + t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccqueue)) } }) return tt } +// addPending pairs the ClientConn reported by the Transport's new-conn hook +// with the testClientConn that tt.accept creates for the server side of the +// same connection. Whichever side arrives first waits in tt.ccpending. +func (tt *testTransport) addPending(nc *synctestNetConn, cc *ClientConn, tc *testClientConn) { + tt.ccMu.Lock() + defer tt.ccMu.Unlock() + + for i, p := range tt.ccpending { + if p.nc != nc { + continue + } + if p.tc != nil { + p.tc.cc = cc + } else if tc != nil { + tc.cc = p.cc + } else { + panic("found matching ccpending for conn with no tc") + } + tt.ccpending = slices.Delete(tt.ccpending, i, i+1) + return + } + + tt.ccpending = append(tt.ccpending, &testPendingClientConn{ + nc: nc, + cc: cc, + tc: tc, + }) +} + +// accept wraps each connection accepted from tt.li in a testClientConn and +// queues it for getConn. It returns once tt.li is closed. +func (tt *testTransport) accept() { + for { + nc, err := tt.li.Accept() + if err != nil { + return + } + tc := newTestClientConnFromNetConn(tt, nc) + tt.addPending(nc.(*synctestNetConn), nil, tc) + tt.ccMu.Lock() + tt.ccqueue = append(tt.ccqueue, tc) + tt.ccMu.Unlock() + } +} + func (tt *testTransport) hasConn() bool { - return len(tt.ccs) > 0 + tt.ccMu.Lock() + defer tt.ccMu.Unlock() + return len(tt.ccqueue) > 0 } func (tt *testTransport) getConn() *testClientConn { tt.t.Helper() - if len(tt.ccs) == 0 { + synctest.Wait() + tt.ccMu.Lock() + if len(tt.ccqueue) == 0 { + tt.ccMu.Unlock() tt.t.Fatalf("no new ClientConns created; wanted one") } - tc := tt.ccs[0] - tt.ccs = tt.ccs[1:] - synctest.Wait() + tc := tt.ccqueue[0] + tt.ccqueue = tt.ccqueue[1:] + tt.ccMu.Unlock() tc.readClientPreface() - synctest.Wait() return tc } func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) rt := &testRoundTrip{ - t: tt.t, - donec: make(chan struct{}), + t: tt.t, + donec: make(chan struct{}), + cancel: cancel, } + go func() { defer close(rt.donec) -
rt.resp, rt.respErr = tt.tr.RoundTrip(req) + switch tt.mode { + case roundTripXNetHTTP2: + rt.resp, rt.respErr = tt.tr.RoundTrip(req) + case roundTripNetHTTP: + rt.resp, rt.respErr = tt.tr1.RoundTrip(req) + } }() synctest.Wait() tt.t.Cleanup(func() { + rt.cancel() if !rt.done() { return } @@ -570,3 +762,114 @@ return rt } + +// nonblockingReader copies data from reader into buf on a background goroutine. +// Read returns only what is already buffered: when buf is empty it reports the +// background goroutine's read error, if any, or errWouldBlock otherwise. +type nonblockingReader struct { + mu sync.Mutex + buf bytes.Buffer + err error + waitc chan struct{} + stopc chan struct{} +} + +func newNonblockingReader(reader io.Reader) *nonblockingReader { + r := &nonblockingReader{} + go func() { + buf := make([]byte, 1024) + for { + n, err := reader.Read(buf) + r.mu.Lock() + if n > 0 { + r.buf.Write(buf[:n]) + } + if err != nil { + r.err = err + } + if r.waitc != nil { + close(r.waitc) + r.waitc = nil + } + stopc := r.stopc + r.mu.Unlock() + if err != nil { + return + } + if stopc != nil { + <-stopc + } + } + }() + return r +} + +func (r *nonblockingReader) Read(p []byte) (n int, err error) { + synctest.Wait() + r.mu.Lock() + defer r.mu.Unlock() + n, err = r.buf.Read(p) + if err == io.EOF { + if r.err != nil { + err = r.err + } else { + err = errWouldBlock + } + } + return n, err +} + +// waitForData waits until data or a read error is available, failing the test +// if none arrives within an hour, and returns how long it waited. +func (r *nonblockingReader) waitForData(t testing.TB) time.Duration { + t.Helper() + synctest.Wait() + waitc := func() chan struct{} { + r.mu.Lock() + defer r.mu.Unlock() + if r.buf.Len() > 0 || r.err != nil { + return nil + } + if r.waitc == nil { + r.waitc = make(chan struct{}) + } + return r.waitc + }() + if waitc == nil { + return 0 + } + start := time.Now() + select { + case <-waitc: + case <-time.After(1 * time.Hour): + t.Fatalf("waited an hour for connection data, saw none") + } + return time.Since(start) +} + +// stop makes the background goroutine pause after its next Read, until start is called. +func (r *nonblockingReader) stop() { + synctest.Wait() + r.mu.Lock() + defer r.mu.Unlock() + if r.stopc != nil { + panic("stopping stopped reader") + } + r.stopc = make(chan struct{}) +} + +// start undoes stop, letting the background goroutine keep reading. +func (r *nonblockingReader) start() { + synctest.Wait() + r.mu.Lock() + defer r.mu.Unlock() + if r.stopc == nil { + panic("starting started reader") + } + stopc := r.stopc + r.stopc = nil +
close(stopc) +} + +// errWouldBlock is returned by nonblockingReader.Read when nothing is buffered and no read error has occurred. +var errWouldBlock = errors.New("would block")
diff --git a/http2/connframes_test.go b/http2/connframes_test.go index d4f0930..a36d0a6 100644 --- a/http2/connframes_test.go +++ b/http2/connframes_test.go
@@ -12,6 +12,7 @@ "reflect" "slices" "testing" + "testing/synctest" . "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -28,7 +29,7 @@ func (tf *testConnFramer) readFrame() Frame { tf.t.Helper() fr, err := tf.fr.ReadFrame() - if err == io.EOF || err == os.ErrDeadlineExceeded { + if err == io.EOF || err == os.ErrDeadlineExceeded || err == errWouldBlock { return nil } if err != nil { @@ -178,7 +179,7 @@ for k, v := range want.header { if !reflect.DeepEqual(v, gotHeader[k]) { - tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k]) + tf.t.Fatalf("got header %q = %q; want %q = %q", k, gotHeader[k], k, v) } } } @@ -309,11 +310,12 @@ func (tf *testConnFramer) wantClosed() { tf.t.Helper() + synctest.Wait() fr, err := tf.fr.ReadFrame() if err == nil { tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr) } - if err == os.ErrDeadlineExceeded { + if err == os.ErrDeadlineExceeded || err == errWouldBlock { // still open: no data behind a read deadline, or nothing buffered by a nonblockingReader tf.t.Fatalf("connection is not closed; want it to be") } } @@ -324,7 +326,7 @@ if err == nil { tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr) } - if err != os.ErrDeadlineExceeded { + if err != os.ErrDeadlineExceeded && err != io.EOF && err != errWouldBlock { tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err) } }
diff --git a/http2/hpack/encode_test.go b/http2/hpack/encode_test.go index 05f12db..9708cce 100644 --- a/http2/hpack/encode_test.go +++ b/http2/hpack/encode_test.go
@@ -8,6 +8,7 @@ "bytes" "encoding/hex" "fmt" + "io" "math/rand" "reflect" "strings" @@ -384,3 +385,14 @@ } } } + +func TestEncodeZeroAlloc(t *testing.T) { + e := NewEncoder(io.Discard) + s := []byte("some string") // string(s) below yields non-constant strings on every call + alloc := testing.AllocsPerRun(100, func() { + e.WriteField(HeaderField{Name: string(s), Value: string(s)}) // stays allocation-free only while the encoder keeps the field from escaping (see headerFieldTable.addEntry) + }) + if alloc != 0 { + t.Errorf("got %v allocs when encoding, want 0", alloc) + } +}
diff --git a/http2/hpack/tables.go b/http2/hpack/tables.go index 8cbdf3f..803fe51 100644 --- a/http2/hpack/tables.go +++ b/http2/hpack/tables.go
@@ -6,6 +6,7 @@ import ( "fmt" + "strings" ) // headerFieldTable implements a list of HeaderFields. @@ -54,10 +55,16 @@ // addEntry adds a new entry. func (t *headerFieldTable) addEntry(f HeaderField) { + // Keep copies of f's strings rather than f itself so that f does not escape to the heap: callers may pass temporary strings, and only an actual insertion pays for the copies. + f2 := HeaderField{ + Name: strings.Clone(f.Name), + Value: strings.Clone(f.Value), + Sensitive: f.Sensitive, + } id := uint64(t.len()) + t.evictCount + 1 - t.byName[f.Name] = id - t.byNameValue[pairNameValue{f.Name, f.Value}] = id - t.ents = append(t.ents, f) + t.byName[f2.Name] = id + t.byNameValue[pairNameValue{f2.Name, f2.Value}] = id + t.ents = append(t.ents, f2) } // evictOldest evicts the n oldest entries in the table.
diff --git a/http2/netconn_test.go b/http2/netconn_test.go index 6eba9e2..4572177 100644 --- a/http2/netconn_test.go +++ b/http2/netconn_test.go
@@ -16,8 +16,69 @@ "sync" "testing/synctest" "time" + + "golang.org/x/net/internal/gate" ) +type synctestNetListener struct { + gate gate.Gate // set while Accept can proceed: a conn is queued or the listener is closed + nextPort uint16 + queue []*synctestNetConn + err error +} + +func newSynctestNetListener() *synctestNetListener { + li := &synctestNetListener{ + gate: gate.New(false), + nextPort: 10000, + } + return li +} + +func (li *synctestNetListener) Accept() (net.Conn, error) { + li.gate.WaitAndLock(context.Background()) // error ignored; presumably only returned on ctx cancellation, which Background never does + defer li.unlock() + if li.err != nil { + return nil, li.err + } + c := li.queue[0] + li.queue = li.queue[1:] + return c, nil +} + +func (li *synctestNetListener) Close() error { + li.gate.Lock() + defer li.unlock() + li.err = net.ErrClosed + for _, c := range li.queue { + c.Close() + } + li.queue = nil + return nil +} + +func (li *synctestNetListener) Addr() net.Addr { + return net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1000")) +} + +func (li *synctestNetListener) newConn() *synctestNetConn { + li.gate.Lock() + defer li.unlock() + cliAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom( + netip.MustParseAddr("127.0.0.1"), + li.nextPort, + )) + li.nextPort++ + cli, srv := synctestNetPipeWithAddrs(cliAddr, li.Addr()) + li.queue = append(li.queue, srv) // NOTE(review): after Close, srv stays queued and is never accepted or closed + return cli +} + +func (li *synctestNetListener) unlock() { + canAccept := len(li.queue) > 0 || li.err != nil + li.gate.Unlock(canAccept) +} + // synctestNetPipe creates an in-memory, full duplex network connection. // Read and write timeouts are managed by the synctest group. // @@ -27,6 +88,10 @@ func synctestNetPipe() (r, w *synctestNetConn) { s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")) s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001")) + return synctestNetPipeWithAddrs(s1addr, s2addr) +} + +func synctestNetPipeWithAddrs(s1addr, s2addr net.Addr) (r, w *synctestNetConn) { s1 := newSynctestNetConnHalf(s1addr) s2 := newSynctestNetConnHalf(s2addr) r = &synctestNetConn{loc: s1, rem: s2}
diff --git a/http2/server_test.go b/http2/server_test.go index 7353085..9e88d4c 100644 --- a/http2/server_test.go +++ b/http2/server_test.go
@@ -2416,22 +2416,7 @@ synctestTest(t, testServer_Rejects_Too_Many_Streams) } func testServer_Rejects_Too_Many_Streams(t testing.TB) { - inHandler := make(chan uint32) - leaveHandler := make(chan bool) - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - var streamID uint32 - if _, err := fmt.Sscanf(r.URL.Path, "/%d", &streamID); err != nil { - t.Errorf("parsing %q: %v", r.URL.Path, err) - } - inHandler <- streamID - <-leaveHandler - }) - defer st.Close() - - // Automatically syncing after every write / before every read - // slows this test down substantially. - st.cc.(*synctestNetConn).autoWait = false - + st := newServerTester(t, nil) // nil handler: handler calls are observed via st.nextHandlerCall below st.greet() nextStreamID := uint32(1) streamID := func() uint32 { @@ -2448,15 +2433,11 @@ EndHeaders: true, }) } + var calls []*serverHandlerCall for i := 0; i < DefaultMaxStreams; i++ { sendReq(streamID()) - <-inHandler + calls = append(calls, st.nextHandlerCall()) } - defer func() { - for i := 0; i < DefaultMaxStreams; i++ { - leaveHandler <- true - } - }() // And this one should cross the limit: // (It's also sent as a CONTINUATION, to verify we still track the decoder context, @@ -2477,7 +2458,7 @@ st.wantRSTStream(rejectID, ErrCodeProtocol) // But let a handler finish: - leaveHandler <- true + calls[0].exit() // let the first stream's handler return st.sync() st.wantHeaders(wantHeader{ streamID: 1, @@ -2487,8 +2468,9 @@ // And now another stream should be able to start: goodID := streamID() sendReq(goodID) - if got := <-inHandler; got != goodID { - t.Errorf("Got stream %d; want %d", got, goodID) + call := st.nextHandlerCall() + if got, want := call.req.URL.Path, fmt.Sprintf("/%d", goodID); got != want { + t.Errorf("Got request for %q, want %q", got, want) } }
diff --git a/http2/testcert_test.go b/http2/testcert_test.go new file mode 100644 index 0000000..477966e --- /dev/null +++ b/http2/testcert_test.go
@@ -0,0 +1,32 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2_test + +import ( + "crypto/tls" + + "golang.org/x/net/internal/testcert" +) + +// TLS configurations for tests that run TLS over in-memory connections. +// Certificate verification is skipped because testCert is a test-only certificate. +var ( + testTLSServerConfig = &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{testCert}, + } + testTLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } +) + +// testCert is the localhost test certificate from internal/testcert, parsed once at init. +var testCert = func() tls.Certificate { + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + panic(err) + } + return cert +}()
diff --git a/http2/transport.go b/http2/transport.go index 2e9c2f6..19553f1 100644 --- a/http2/transport.go +++ b/http2/transport.go
@@ -718,9 +718,6 @@ } func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { - if t.transportTestHooks != nil { - return t.newClientConn(nil, singleUse, nil) - } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -2861,6 +2858,9 @@ var seenMaxConcurrentStreams bool err := f.ForeachSetting(func(s Setting) error { + if err := s.Valid(); err != nil { // validate every setting, not only ENABLE_CONNECT_PROTOCOL, before the switch applies it + return err + } switch s.ID { case SettingMaxFrameSize: cc.maxFrameSize = s.Val @@ -2892,9 +2892,6 @@ cc.henc.SetMaxDynamicTableSize(s.Val) cc.peerMaxHeaderTableSize = s.Val case SettingEnableConnectProtocol: - if err := s.Valid(); err != nil { - return err - } // If the peer wants to send us SETTINGS_ENABLE_CONNECT_PROTOCOL, // we require that it do so in the first SETTINGS frame. //
diff --git a/http2/transport_api_test.go b/http2/transport_api_test.go new file mode 100644 index 0000000..b146adb --- /dev/null +++ b/http2/transport_api_test.go
@@ -0,0 +1,836 @@
+// Copyright 2026 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file tests the exported, user-facing parts of Transport.
+//
+// These tests verify that Transport behavior is consistent when using either the
+// HTTP/2 implementation in this package (x/net/http2), or when using the
+// implementation in net/http/internal/http2.
+
+package http2_test
+
+import (
+	"context"
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"slices"
+	"strings"
+	"testing"
+	"testing/synctest"
+	"time"
+
+	"golang.org/x/net/http2"
+)
+
+// synctestTestRoundTrip runs f as the subtests "netHTTP" and "xNetHTTP2",
+// each in its own synctest bubble and with the matching roundTripTestMode.
+func synctestTestRoundTrip(t *testing.T, f func(t *testing.T, mode roundTripTestMode)) {
+	t.Run("netHTTP", func(t *testing.T) {
+		synctest.Test(t, func(t *testing.T) {
+			f(t, roundTripNetHTTP)
+		})
+	})
+	t.Run("xNetHTTP2", func(t *testing.T) {
+		synctest.Test(t, func(t *testing.T) {
+			f(t, roundTripXNetHTTP2)
+		})
+	})
+}
+
+// TestAPITransportDial tests the Transport.DialTLS, Transport.DialTLSContext,
+// and Transport.TLSClientConfig fields.
+func TestAPITransportDial(t *testing.T) {
+	t.Run("DialTLS/http", func(t *testing.T) {
+		testAPITransportDial(t, "DialTLS", "http")
+	})
+	t.Run("DialTLS/https", func(t *testing.T) {
+		testAPITransportDial(t, "DialTLS", "https")
+	})
+	t.Run("DialTLSContext/http", func(t *testing.T) {
+		testAPITransportDial(t, "DialTLSContext", "http")
+	})
+	t.Run("DialTLSContext/https", func(t *testing.T) {
+		testAPITransportDial(t, "DialTLSContext", "https")
+	})
+}
+func testAPITransportDial(t *testing.T, name, proto string) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		const serverName = "server.tld" // URL host; the dialer must see it as tls.Config.ServerName
+		var (
+			dialCalled             = false
+			dialServerName         = ""
+			verifyConnectionCalled = false
+			tt                     *testTransport // assigned below, before roundTrip can dial
+		)
+		dialFunc := func(network, address string, tlsConf *tls.Config) (net.Conn, error) { // backs both DialTLS and DialTLSContext
+			dialCalled = true
+			dialServerName = tlsConf.ServerName
+			if got, want := tlsConf.NextProtos, []string{"h2"}; !slices.Equal(got, want) {
+				t.Errorf("tls.Config.NextProtos = %q, want %q", got, want)
+			}
+			conn := tt.li.newConn()
+			switch proto {
+			case "http": // plaintext; AllowHTTP is enabled below for this case
+				return conn, nil
+			case "https": // handshake over the test conn using the tlsConf given to the dialer
+				tlsConn := tls.Client(conn, tlsConf)
+				if err := tlsConn.Handshake(); err != nil {
+					t.Errorf("client TLS handshake: %v", err)
+				}
+				return tlsConn, nil
+			default:
+				panic("unknown proto: " + proto)
+			}
+		}
+		tt = newTestTransport(t, mode, func(tr2 *http2.Transport) {
+			tr2.TLSClientConfig = &tls.Config{
+				InsecureSkipVerify: true,
+				VerifyConnection: func(tls.ConnectionState) error {
+					// We check to see if the Transport's TLSClientConfig
+					// was used by observing if VerifyConnection was called.
+					verifyConnectionCalled = true
+					return nil
+				},
+			}
+
+			switch name {
+			case "DialTLS":
+				tr2.DialTLS = dialFunc
+				tr2.DialTLSContext = nil
+			case "DialTLSContext":
+				tr2.DialTLS = nil
+				tr2.DialTLSContext = func(ctx context.Context, network, address string, tlsConf *tls.Config) (net.Conn, error) {
+					return dialFunc(network, address, tlsConf)
+				}
+			default:
+				panic("unknown func: " + name)
+			}
+			if proto == "http" {
+				tr2.AllowHTTP = true
+			}
+		})
+		if proto == "https" {
+			tt.useTLS = true
+		}
+
+		req, _ := http.NewRequest("GET", proto+"://"+serverName+"/", nil)
+		_ = tt.roundTrip(req)
+		tc := tt.getConn()
+		tc.wantFrameType(http2.FrameSettings)
+
+		// When we use http.Transport.RoundTrip, it handles the dial and ignores
+		// the http2.Transport's dialer.
+		wantDialCalled := mode == roundTripXNetHTTP2
+		if dialCalled != wantDialCalled {
+			t.Errorf("Transport.%v called: %v, want %v", name, dialCalled, wantDialCalled)
+		}
+
+		// If the VerifyConnection hook is called, this indicates that we
+		// correctly used the http2.Transport.TLSClientConfig.
+		if got, want := verifyConnectionCalled, (proto == "https" && wantDialCalled); got != want {
+			t.Errorf("TLSConfig.VerifyConnection called: %v, want %v", got, want)
+		}
+
+		// If the dial function is called, it should be provided with a *tls.Config
+		// with the ServerName filled in correctly.
+		if dialCalled && dialServerName != serverName {
+			t.Errorf("TLSConfig.ServerName = %q, want %q", dialServerName, serverName)
+		}
+	})
+}
+
+// TestAPITransportDisableCompression tests the Transport.DisableCompression field.
+func TestAPITransportDisableCompression(t *testing.T) {
+	for _, disable := range []bool{true, false} {
+		t.Run(fmt.Sprint(disable), func(t *testing.T) {
+			synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+				tc := newTestClientConn(t, mode, func(tr2 *http2.Transport) {
+					tr2.DisableCompression = disable
+				})
+				tc.greet()
+
+				req, _ := http.NewRequest("PUT", "https://dummy.tld/", nil)
+				_ = tc.roundTrip(req)
+
+				var want []string // stays nil when compression is disabled
+				if !disable {
+					want = []string{"gzip"}
+				}
+				tc.wantHeaders(wantHeader{
+					streamID:  1,
+					endStream: true,
+					header: http.Header{
+						"accept-encoding": want,
+					},
+				})
+			})
+		})
+	}
+}
+
+// TestAPITransportAllowHTTPOff tests the Transport.AllowHTTP field.
+// (It only tests AllowHTTP = false, since most other tests use AllowHTTP = true.)
+func TestAPITransportAllowHTTPOff(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		tt := newTestTransport(t, mode, func(tr2 *http2.Transport) {
+			tr2.AllowHTTP = false
+		})
+
+		req, _ := http.NewRequest("GET", "http://dummy.tld/", nil)
+		rt := tt.roundTrip(req)
+		switch mode {
+		case roundTripNetHTTP:
+			// net/http.Transport doesn't respect http2.Transport.AllowHTTP.
+			// When using a net/http Transport, unencrypted HTTP/2 is allowed when
+			// the transport Protocols contains UnencryptedHTTP2.
+			tc := tt.getConn()
+			tc.wantFrameType(http2.FrameSettings)
+			tc.wantFrameType(http2.FrameWindowUpdate)
+			tc.wantFrameType(http2.FrameHeaders)
+		case roundTripXNetHTTP2:
+			// x/net/http2.Transport rejects http URLs when AllowHTTP is false.
+			err := rt.err()
+			want := "unencrypted HTTP/2 not enabled"
+			if err == nil || !strings.Contains(err.Error(), want) {
+				t.Errorf("RoundTrip = %v; want %q", err, want)
+			}
+		}
+	})
+}
+
+// TestAPITransportMaxHeaderListSize tests the Transport.MaxHeaderListSize field.
+func TestAPITransportMaxHeaderListSize(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		const size = 20000
+		tc := newTestClientConn(t, mode, func(tr2 *http2.Transport) {
+			tr2.MaxHeaderListSize = size
+		})
+		tc.wantSettings(map[http2.SettingID]uint32{
+			http2.SettingMaxHeaderListSize: size,
+		})
+	})
+}
+
+// TestAPITransportMaxDecoderHeaderTableSize tests the Transport.MaxDecoderHeaderTableSize field.
+func TestAPITransportMaxDecoderHeaderTableSize(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		const size = 10000
+		tc := newTestClientConn(t, mode, func(tr2 *http2.Transport) {
+			tr2.MaxDecoderHeaderTableSize = size
+		})
+		tc.wantSettings(map[http2.SettingID]uint32{
+			http2.SettingHeaderTableSize: size,
+		})
+	})
+}
+
+// TestAPITransportMaxEncoderHeaderTableSize should go here,
+// but it's difficult to verify the effects of the encoder table.
+
+// TestAPITransportMaxReadFrameSize tests the Transport.MaxReadFrameSize field.
+func TestAPITransportMaxReadFrameSize(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		const size = 20000
+		tc := newTestClientConn(t, mode, func(tr2 *http2.Transport) {
+			tr2.MaxReadFrameSize = size
+		})
+		tc.wantSettings(map[http2.SettingID]uint32{
+			http2.SettingMaxFrameSize: size,
+		})
+	})
+}
+
+// TestAPITransportStrictMaxConcurrentStreamsEnabled tests the
+// Transport.StrictMaxConcurrentStreams field.
+func TestAPITransportStrictMaxConcurrentStreamsEnabled(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		tt := newTestTransport(t, mode, func(tr2 *http2.Transport) {
+			tr2.StrictMaxConcurrentStreams = true
+		})
+
+		// Request 1: Sent on a new connection.
+		// The server then sets MaxConcurrentStreams = 1.
+		req1, _ := http.NewRequest("GET", "http://dummy.tld/1", nil)
+		rt1 := tt.roundTrip(req1)
+
+		tc1 := tt.getConn()
+		tc1.wantFrameType(http2.FrameSettings)
+		tc1.wantFrameType(http2.FrameWindowUpdate)
+		tc1.wantHeaders(wantHeader{
+			streamID:  1,
+			endStream: true,
+			header: http.Header{
+				":authority": []string{"dummy.tld"},
+				":method":    []string{"GET"},
+				":path":      []string{"/1"},
+			},
+		})
+
+		tc1.writeSettings(http2.Setting{
+			ID:  http2.SettingMaxConcurrentStreams,
+			Val: 1,
+		})
+		tc1.wantFrameType(http2.FrameSettings)
+
+		// Request 2: Blocks, because request 1 is consuming the stream
+		// concurrency slot.
+		req2, _ := http.NewRequest("GET", "http://dummy.tld/2", nil)
+		_ = tt.roundTrip(req2)
+		tc1.wantIdle()
+
+		// Send a response to request 1.
+		// Request 2 can now be sent.
+		tc1.writeHeaders(http2.HeadersFrameParam{
+			StreamID:   1,
+			EndHeaders: true,
+			EndStream:  true,
+			BlockFragment: tc1.makeHeaderBlockFragment(
+				":status", "200",
+			),
+		})
+		rt1.wantStatus(200)
+		tc1.wantHeaders(wantHeader{
+			streamID:  3,
+			endStream: true,
+			header: http.Header{
+				":authority": []string{"dummy.tld"},
+				":method":    []string{"GET"},
+				":path":      []string{"/2"},
+			},
+		})
+	})
+}
+
+// TestAPITransportStrictMaxConcurrentStreamsDisabled tests the
+// Transport.StrictMaxConcurrentStreams field.
+func TestAPITransportStrictMaxConcurrentStreamsDisabled(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		tt := newTestTransport(t, mode, func(tr2 *http2.Transport) {
+			tr2.StrictMaxConcurrentStreams = false
+		})
+
+		// Request 1: Sent on a new connection.
+		// The server then sets MaxConcurrentStreams = 1.
+		req1, _ := http.NewRequest("GET", "http://dummy.tld/1", nil)
+		_ = tt.roundTrip(req1)
+
+		tc1 := tt.getConn()
+		tc1.wantFrameType(http2.FrameSettings)
+		tc1.wantFrameType(http2.FrameWindowUpdate)
+		tc1.wantHeaders(wantHeader{
+			streamID:  1,
+			endStream: true,
+			header: http.Header{
+				":authority": []string{"dummy.tld"},
+				":method":    []string{"GET"},
+				":path":      []string{"/1"},
+			},
+		})
+
+		tc1.writeSettings(http2.Setting{
+			ID:  http2.SettingMaxConcurrentStreams,
+			Val: 1,
+		})
+		tc1.wantFrameType(http2.FrameSettings)
+
+		// Request 2: Sent on a new connection, because request 1 is consuming the
+		// first connection's stream concurrency slot.
+		req2, _ := http.NewRequest("GET", "http://dummy.tld/2", nil)
+		_ = tt.roundTrip(req2)
+		tc2 := tt.getConn()
+		tc2.wantFrameType(http2.FrameSettings)
+		tc2.wantFrameType(http2.FrameWindowUpdate)
+		tc2.wantHeaders(wantHeader{
+			streamID:  1,
+			endStream: true,
+			header: http.Header{
+				":authority": []string{"dummy.tld"},
+				":method":    []string{"GET"},
+				":path":      []string{"/2"},
+			},
+		})
+	})
+}
+
+// TestAPITransportIdleConnTimeout tests the Transport.IdleConnTimeout field.
+func TestAPITransportIdleConnTimeout(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		const idleConnTimeout = 3 * time.Second
+		tc := newTestClientConn(t, mode, func(tr2 *http2.Transport) {
+			tr2.IdleConnTimeout = idleConnTimeout
+		})
+		tc.greet()
+		tc.wantIdle()
+
+		closeDelay := tc.connReader.waitForData(t) // time until the conn closes
+		tc.wantClosed()
+		if got, want := closeDelay, idleConnTimeout; got != want {
+			t.Errorf("time until close: %v, want %v", got, want)
+		}
+	})
+}
+
+// TestAPITransportPingTimeout tests the
+// Transport.ReadIdleTimeout and Transport.PingTimeout fields.
+func TestAPITransportPingTimeout(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		const readIdleTimeout = 3 * time.Second
+		const pingTimeout = 5 * time.Second
+		tc := newTestClientConn(t, mode, func(tr2 *http2.Transport) {
+			tr2.ReadIdleTimeout = readIdleTimeout
+			tr2.PingTimeout = pingTimeout
+		})
+		tc.greet()
+		tc.wantIdle()
+
+		// PING is sent after ReadIdleTimeout.
+		pingDelay := tc.connReader.waitForData(t)
+		tc.wantFrameType(http2.FramePing)
+		if got, want := pingDelay, readIdleTimeout; got != want {
+			t.Errorf("time until PING: %v, want %v", got, want)
+		}
+
+		// Connection is closed after PingTimeout.
+		closeDelay := tc.connReader.waitForData(t)
+		tc.wantClosed()
+		if got, want := closeDelay, pingTimeout; got != want {
+			t.Errorf("time after PING until close: %v, want %v", got, want)
+		}
+	})
+}
+
+// TestAPITransportWriteByteTimeout tests the Transport.WriteByteTimeout field.
+func TestAPITransportWriteByteTimeout(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		const writeByteTimeout = 3 * time.Second
+		tt := newTestTransport(t, mode, func(tr2 *http2.Transport) {
+			tr2.WriteByteTimeout = writeByteTimeout
+		})
+
+		req1, _ := http.NewRequest("GET", "http://dummy.tld/1", nil)
+		_ = tt.roundTrip(req1)
+
+		tc := tt.getConn()
+		tc.wantFrameType(http2.FrameSettings)
+		tc.wantFrameType(http2.FrameWindowUpdate)
+		tc.wantFrameType(http2.FrameHeaders)
+
+		// Block writes (past the first byte).
+		// Just sleeping for WriteByteTimeout shouldn't do anything.
+		tc.connReader.stop()
+		tc.netconn.SetReadBufferSize(1)
+		time.Sleep(2 * writeByteTimeout)
+		tc.wantIdle()
+
+		// Sending a new request will fail after the write timeout.
+		//
+		// We need to sleep for 2*writeByteTimeout, since we'll read
+		// 2 bytes during the first timeout period. (We set a 1-byte buffer,
+		// the smallest the test conn permits, which still allows for writing a byte.)
+		req2, _ := http.NewRequest("GET", "http://dummy.tld/", nil)
+		_ = tt.roundTrip(req2)
+		time.Sleep(2 * writeByteTimeout)
+		synctest.Wait()
+
+		// Drain the partial request from the conn,
+		// after which we can observe that it has been closed.
+		tc.connReader.start()
+		io.Copy(io.Discard, tc.connReader)
+		tc.wantClosed()
+	})
+}
+
+// TestAPITransportCountError tests the Transport.CountError field.
+func TestAPITransportCountError(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		countError := 0
+		tc := newTestClientConn(t, mode, func(tr2 *http2.Transport) {
+			tr2.CountError = func(errType string) {
+				countError++
+			}
+		})
+		tc.greet()
+
+		tc.netconn.Close()
+		synctest.Wait()
+		if countError != 1 {
+			t.Errorf("after connection error: CountError called %v times, want 1", countError)
+		}
+	})
+}
+
+type testClientConnPool struct {
+	t        *testing.T
+	li       *synctestNetListener // source of the conns GetClientConn dials
+	tr2      *http2.Transport     // builds ClientConns via NewClientConn
+	wantReq  *http.Request        // non-nil while a GetClientConn call is expected
+	wantDead *http2.ClientConn    // non-nil while a MarkDead call is expected
+	conns    []*http2.ClientConn  // every ClientConn created by GetClientConn
+	tlsConf  *tls.Config          // client TLS config for https requests
+}
+
+func (p *testClientConnPool) GetClientConn(req *http.Request, addr string) (cc *http2.ClientConn, err error) {
+	if p.wantReq == nil {
+		p.t.Errorf("unexpected call to ClientConnPool.GetClientConn")
+	}
+	conn := net.Conn(p.li.newConn())
+	if req.URL.Scheme == "https" {
+		conn = tls.Client(conn, p.tlsConf)
+	}
+	cc, err = p.tr2.NewClientConn(conn)
+	if cc != nil {
+		p.conns = append(p.conns, cc)
+	}
+	p.wantReq = nil
+	return cc, err
+}
+
+func (p *testClientConnPool) MarkDead(cc *http2.ClientConn) {
+	if p.wantDead == nil {
+		p.t.Errorf("unexpected call to ClientConnPool.MarkDead")
+	}
+	p.wantDead = nil
+}
+
+func (p *testClientConnPool) check() {
+	p.t.Helper()
+	synctest.Wait()
+	if p.wantReq != nil {
+		p.t.Errorf("wanted call to ClientConnPool.GetClientConn, got none")
+	}
+	if p.wantDead != nil {
+		p.t.Errorf("wanted call to ClientConnPool.MarkDead, got none")
+	}
+}
+
+// TestAPITransportConnPool tests the Transport.ConnPool field.
+func TestAPITransportConnPool(t *testing.T) {
+	synctestTestRoundTrip(t, func(t *testing.T, mode roundTripTestMode) {
+		// The pool creates its conns itself, with tt.tr.NewClientConn over tt.li.
+		pool := &testClientConnPool{
+			t: t,
+		}
+		tt := newTestTransport(t, mode, func(tr2 *http2.Transport) {
+			tr2.ConnPool = pool
+		})
+		tt.useTLS = true
+		pool.tr2 = tt.tr
+		pool.li = tt.li
+		pool.tlsConf = testTLSClientConfig.Clone()
+		pool.tlsConf.NextProtos = []string{"h2"}
+
+		// Send a request. The pool creates a new connection.
+		req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		pool.wantReq = req1
+		rt := tt.roundTrip(req1)
+		pool.check()
+		if len(pool.conns) != 1 {
+			t.Fatalf("expected pool to create 1 conn, got %v", len(pool.conns))
+		}
+
+		// This is the connection created by the ClientConnPool.
+		tc1 := tt.getConn()
+		tc1.wantFrameType(http2.FrameSettings)
+		tc1.wantFrameType(http2.FrameWindowUpdate)
+		tc1.wantFrameType(http2.FrameHeaders)
+
+		tc1.writeSettings()
+		tc1.wantFrameType(http2.FrameSettings) // ACK
+
+		tc1.writeHeaders(http2.HeadersFrameParam{
+			StreamID:   1,
+			EndHeaders: true,
+			EndStream:  true,
+			BlockFragment: tc1.makeHeaderBlockFragment(
+				":status", "200",
+			),
+		})
+		rt.wantStatus(200)
+
+		// ClientConnPool.MarkDead is called when the connection closes.
+		pool.wantDead = pool.conns[0]
+		tc1.netconn.Close()
+		pool.check()
+	})
+}
+
+// TestAPITransportNewClientConn tests the Transport.NewClientConn method.
+func TestAPITransportNewClientConn(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		tt := newTestTransport(t, roundTripXNetHTTP2, func(tr2 *http2.Transport) {
+			// ClientConnState.LastIdle is only set when there is an idle timer.
+			tr2.IdleConnTimeout = 10 * time.Second
+		})
+
+		nc := tt.li.newConn()
+		cc, err := tt.tr.NewClientConn(nc)
+		if err != nil {
+			t.Fatalf("NewClientConn: %v", err)
+		}
+
+		tc1 := tt.getConn()
+		tc1.wantFrameType(http2.FrameSettings)
+		tc1.wantFrameType(http2.FrameWindowUpdate)
+		tc1.writeSettings(http2.Setting{
+			ID:  http2.SettingMaxConcurrentStreams,
+			Val: 1,
+		})
+		tc1.wantFrameType(http2.FrameSettings) // ACK
+
+		synctest.Wait()
+		wantClientConnState(t, cc.State(), http2.ClientConnState{
+			MaxConcurrentStreams: 1,
+		})
+
+		if got, want := cc.CanTakeNewRequest(), true; got != want {
+			t.Errorf("cc.CanTakeNewRequest() = %v, want %v", got, want)
+		}
+		if got, want := cc.ReserveNewRequest(), true; got != want {
+			t.Errorf("cc.ReserveNewRequest() = %v, want %v", got, want)
+		}
+		// Reservation has consumed the one concurrency slot.
+		if got, want := cc.CanTakeNewRequest(), false; got != want {
+			t.Errorf("cc.CanTakeNewRequest() = %v, want %v (sole request slot reserved)", got, want)
+		}
+		wantClientConnState(t, cc.State(), http2.ClientConnState{
+			StreamsReserved:      1,
+			MaxConcurrentStreams: 1,
+		})
+
+		// Consume the reservation by sending a request.
+		req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		rt1 := newTestRoundTrip(t, req1, cc.RoundTrip)
+		tc1.wantFrameType(http2.FrameHeaders)
+		wantClientConnState(t, cc.State(), http2.ClientConnState{
+			StreamsActive:        1,
+			MaxConcurrentStreams: 1,
+		})
+
+		tc1.writeHeaders(http2.HeadersFrameParam{
+			StreamID:   1,
+			EndHeaders: true,
+			EndStream:  true,
+			BlockFragment: tc1.makeHeaderBlockFragment(
+				":status", "200",
+			),
+		})
+		rt1.wantStatus(200)
+		if got, want := cc.CanTakeNewRequest(), true; got != want {
+			t.Errorf("cc.CanTakeNewRequest() = %v, want %v", got, want)
+		}
+		wantClientConnState(t, cc.State(), http2.ClientConnState{
+			MaxConcurrentStreams: 1,
+			LastIdle:             time.Now(),
+		})
+
+		cc.SetDoNotReuse()
+		wantClientConnState(t, cc.State(), http2.ClientConnState{
+			Closing:              true,
+			MaxConcurrentStreams: 1,
+			LastIdle:             time.Now(),
+		})
+		req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		rt2 := newTestRoundTrip(t, req2, cc.RoundTrip)
+		if rt2.err() == nil {
+			t.Fatalf("RoundTrip after SetDoNotReuse: succeeded, want error")
+		}
+
+		if err := cc.Close(); err != nil {
+			t.Errorf("cc.Close() = %v, want nil", err)
+		}
+		wantClientConnState(t, cc.State(), http2.ClientConnState{
+			Closing:              true,
+			Closed:               true,
+			MaxConcurrentStreams: 1,
+			LastIdle:             time.Now(),
+		})
+	})
+}
+
+// TestAPIClientConnPing tests the ClientConn.Ping method.
+func TestAPIClientConnPing(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		tt := newTestTransport(t, roundTripXNetHTTP2)
+
+		nc := tt.li.newConn()
+		cc, err := tt.tr.NewClientConn(nc)
+		if err != nil {
+			t.Fatalf("NewClientConn: %v", err)
+		}
+
+		tc1 := tt.getConn()
+		tc1.wantFrameType(http2.FrameSettings)
+		tc1.wantFrameType(http2.FrameWindowUpdate)
+		tc1.writeSettings(http2.Setting{
+			ID:  http2.SettingMaxConcurrentStreams,
+			Val: 1,
+		})
+		tc1.wantFrameType(http2.FrameSettings) // ACK
+		tc1.wantIdle()
+
+		// Ping with successful response.
+		pingDone := false
+		var pingErr error
+		go func() {
+			pingErr = cc.Ping(context.Background())
+			pingDone = true
+		}()
+		synctest.Wait()
+		fr := readFrame[*http2.PingFrame](t, tc1)
+		if pingDone {
+			t.Fatalf("cc.Ping() = %v; want to still be running", pingErr)
+		}
+		tc1.writePing(true, fr.Data)
+		synctest.Wait()
+		if !pingDone {
+			t.Fatalf("cc.Ping() still running; want to be done")
+		}
+		if pingErr != nil {
+			t.Fatalf("cc.Ping() = %v; want nil", pingErr)
+		}
+
+		// Ping with no response.
+		const timeout = 1 * time.Second
+		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		defer cancel()
+		start := time.Now()
+		if err := cc.Ping(ctx); err == nil {
+			t.Fatalf("cc.Ping() = nil; want error")
+		}
+		if got, want := time.Since(start), timeout; got != want {
+			t.Fatalf("cc.Ping() returned after %v, want %v", got, want)
+		}
+	})
+}
+
+// TestAPIClientConnShutdownSuccess tests the ClientConn.Shutdown method.
+// Shutdown returns after the last request on the connection completes.
+func TestAPIClientConnShutdownSuccess(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		tt := newTestTransport(t, roundTripXNetHTTP2)
+
+		nc := tt.li.newConn()
+		cc, err := tt.tr.NewClientConn(nc)
+		if err != nil {
+			t.Fatalf("NewClientConn: %v", err)
+		}
+
+		tc1 := tt.getConn()
+		tc1.greet()
+
+		// Start a request.
+		req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		rt1 := newTestRoundTrip(t, req1, cc.RoundTrip)
+		tc1.wantFrameType(http2.FrameHeaders)
+
+		// Start shutdown.
+		var shutdownErr error
+		shutdownDone := false
+		go func() {
+			shutdownErr = cc.Shutdown(context.Background())
+			shutdownDone = true
+		}()
+
+		synctest.Wait()
+		if shutdownDone {
+			t.Fatalf("cc.Shutdown() = %v; want still running with req in flight", shutdownErr)
+		}
+		if rt1.done() {
+			t.Fatalf("RoundTrip finished; want still running")
+		}
+
+		// Server terminates the outstanding request, connection shuts down.
+		tc1.writeRSTStream(1, http2.ErrCodeCancel)
+		synctest.Wait()
+
+		if !shutdownDone {
+			t.Fatalf("cc.Shutdown() still running; want to have returned")
+		}
+		if shutdownErr != nil {
+			t.Fatalf("cc.Shutdown() = %v; want nil", shutdownErr)
+		}
+		if err := rt1.err(); err == nil {
+			t.Fatalf("RoundTrip succeeded; want error")
+		}
+
+		// We might send a GOAWAY frame before closing.
+		if fr := tc1.readFrame(); fr != nil {
+			if _, ok := fr.(*http2.GoAwayFrame); !ok {
+				t.Fatalf("read frame %v; want GOAWAY or nothing", fr)
+			}
+		}
+		tc1.wantClosed()
+	})
+}
+
+// TestAPIClientConnShutdownFailure tests the ClientConn.Shutdown method.
+// Shutdown's context expires before the last request on the connection completes.
+func TestAPIClientConnShutdownFailure(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		tt := newTestTransport(t, roundTripXNetHTTP2)
+
+		nc := tt.li.newConn()
+		cc, err := tt.tr.NewClientConn(nc)
+		if err != nil {
+			t.Fatalf("NewClientConn: %v", err)
+		}
+
+		tc1 := tt.getConn()
+		tc1.greet()
+
+		// Start a request.
+		req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		_ = newTestRoundTrip(t, req1, cc.RoundTrip)
+		tc1.wantFrameType(http2.FrameHeaders)
+
+		// Shutdown's context expires before the request completes.
+		const timeout = 1 * time.Second
+		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		defer cancel()
+		start := time.Now()
+		if err := cc.Shutdown(ctx); !errors.Is(err, context.DeadlineExceeded) {
+			t.Errorf("cc.Shutdown() = %v; want DeadlineExceeded", err)
+		}
+		if got, want := time.Since(start), timeout; got != want {
+			t.Fatalf("cc.Shutdown() returned after %v, want %v", got, want)
+		}
+
+		// Shutdown may have sent a GOAWAY frame; nothing else should follow.
+		if fr := tc1.readFrame(); fr != nil {
+			if _, ok := fr.(*http2.GoAwayFrame); !ok {
+				t.Fatalf("read frame %v; want GOAWAY or nothing", fr)
+			}
+		}
+		tc1.wantIdle()
+	})
+}
+
+func wantClientConnState(t *testing.T, a, b http2.ClientConnState) {
+	t.Helper()
+	if got, want := a.Closed, b.Closed; got != want {
+		t.Errorf("ClientConnState.Closed = %v, want %v", got, want)
+	}
+	if got, want := a.Closing, b.Closing; got != want {
+		t.Errorf("ClientConnState.Closing = %v, want %v", got, want)
+	}
+	if got, want := a.StreamsActive, b.StreamsActive; got != want {
+		t.Errorf("ClientConnState.StreamsActive = %v, want %v", got, want)
+	}
+	if got, want := a.StreamsReserved, b.StreamsReserved; got != want {
+		t.Errorf("ClientConnState.StreamsReserved = %v, want %v", got, want)
+	}
+	if got, want := a.StreamsPending, b.StreamsPending; got != want {
+		t.Errorf("ClientConnState.StreamsPending = %v, want %v", got, want)
+	}
+	if got, want := a.MaxConcurrentStreams, b.MaxConcurrentStreams; got != want {
+		t.Errorf("ClientConnState.MaxConcurrentStreams = %v, want %v", got, want)
+	}
+	if got, want := a.LastIdle, b.LastIdle; !got.Equal(want) {
+		t.Errorf("ClientConnState.LastIdle = %v, want %v", got, want)
+	}
+}
diff --git a/http2/transport_test.go b/http2/transport_test.go index d948b88..4610be4 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go
@@ -1858,6 +1858,7 @@ tc.wantRSTStream(rt.streamID(), ErrCodeFlowControl) tc.writeWindowUpdate(0, windowIncrease) + synctest.Wait() tc.wantClosed() } @@ -2777,6 +2778,7 @@ tc := newTestClientConn(t) tc.greet() tc.closeWrite() + synctest.Wait() const body = "foo" req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body))) @@ -3213,6 +3215,7 @@ start := time.Now() for streamID := uint32(1); !rt.done(); streamID += 2 { count++ + tc.connReader.waitForData(t) tc.wantHeaders(wantHeader{ streamID: streamID, endStream: true, @@ -3456,6 +3459,7 @@ } tc.writeSettings(Setting{SettingHeaderTableSize, resSize}) + synctest.Wait() if got, want := tc.cc.TestPeerMaxHeaderTableSize(), resSize; got != want { t.Fatalf("peerHeaderTableSize = %d, want %d", got, want) } @@ -5269,7 +5273,7 @@ tc.wantFrameType(FrameHeaders) for i := 0; i < test.hcount; i++ { - if fr, err := tc.fr.ReadFrame(); err != os.ErrDeadlineExceeded { + if fr, err := tc.fr.ReadFrame(); err != errWouldBlock { t.Fatalf("after writing %v 1xx headers: read %v, %v; want idle", i, fr, err) } tc.writeHeaders(HeadersFrameParam{ @@ -5430,6 +5434,7 @@ // The server responds to our PING. tc.writePing(true, pf1.Data) + synctest.Wait() // Create yet another request and cancel it. // Still no PING frame; we got a response to the previous one, @@ -5445,6 +5450,7 @@ ":status", "200", ), }) + synctest.Wait() // One more request. // This time we send a PING frame. @@ -5531,9 +5537,10 @@ t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) + tt.useTLS = true // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. 
- cli, _ := synctestNetPipe() + cli := tt.li.newConn() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { t1.TLSNextProto["h2"]("dummy.tld", cliTLS) @@ -5575,9 +5582,10 @@ t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) + tt.useTLS = true // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe() + cli := tt.li.newConn() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { t1.TLSNextProto["h2"]("dummy.tld", cliTLS) @@ -5617,9 +5625,10 @@ } t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) + tt.useTLS = true // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe() + cli := tt.li.newConn() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { t1.TLSNextProto["h2"]("dummy.tld", cliTLS) @@ -5650,9 +5659,10 @@ t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) + tt.useTLS = true // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe() + cli := tt.li.newConn() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { t1.TLSNextProto["h2"]("dummy.tld", cliTLS) @@ -5677,6 +5687,19 @@ } } +func TestTransportDoNotHangOnZeroMaxFrameSize(t *testing.T) { + synctestTest(t, testTransportDoNotHangOnZeroMaxFrameSize) +} +func testTransportDoNotHangOnZeroMaxFrameSize(t testing.TB) { + tc := newTestClientConn(t) + tc.writeSettings(Setting{ID: SettingMaxFrameSize, Val: 0}) + tc.wantFrameType(FrameSettings) + + req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader("body")) + tc.roundTrip(req) + // Previously, https://go.dev/issue/78476 caused an infinite hang here. +} + func TestExtendedConnectClientWithServerSupport(t *testing.T) { SetDisableExtendedConnectProtocol(t, false) ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
diff --git a/http3/http3.go b/http3/http3.go new file mode 100644 index 0000000..72ac090 --- /dev/null +++ b/http3/http3.go
@@ -0,0 +1,46 @@
+// Copyright 2026 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http3
+
+import (
+	"net/http"
+	_ "unsafe" // for linkname
+
+	. "golang.org/x/net/internal/http3"
+	"golang.org/x/net/quic"
+)
+
+// registerHTTP3Server provides net/http_test.registerHTTP3Server via linkname.
+// It calls RegisterServer for s with a ListenQUIC hook that sends the QUIC
+// endpoint's local address on the returned (unbuffered) channel, blocking
+// until it is received. If listening fails, the hook closes the channel
+// instead, so a waiting receiver gets "" rather than blocking forever.
+//
+//go:linkname registerHTTP3Server net/http_test.registerHTTP3Server
+func registerHTTP3Server(s *http.Server) <-chan string {
+	listenAddr := make(chan string)
+	RegisterServer(s, ServerOpts{
+		ListenQUIC: func(addr string, config *quic.Config) (*quic.Endpoint, error) {
+			e, err := quic.Listen("udp", addr, config)
+			if err != nil {
+				// e may be nil on failure. NOTE: closing assumes ListenQUIC
+				// runs at most once per server.
+				close(listenAddr)
+				return nil, err
+			}
+			listenAddr <- e.LocalAddr().String()
+			return e, nil
+		},
+	})
+	return listenAddr
+}
+
+// registerHTTP3Transport provides net/http_test.registerHTTP3Transport via
+// linkname. It calls RegisterTransport for tr.
+//
+//go:linkname registerHTTP3Transport net/http_test.registerHTTP3Transport
+func registerHTTP3Transport(tr *http.Transport) {
+	RegisterTransport(tr)
+}
diff --git a/internal/http3/http3.go b/internal/http3/http3.go index edbba0c..189e3e7 100644 --- a/internal/http3/http3.go +++ b/internal/http3/http3.go
@@ -4,7 +4,10 @@ package http3 -import "fmt" +import ( + "context" + "fmt" +) // Stream types. // @@ -31,6 +34,14 @@ streamTypeDecoder = streamType(0x03) ) +// canceledCtx is a canceled Context. +// Used for performing non-blocking QUIC operations. +var canceledCtx = func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx +}() + func (stype streamType) String() string { switch stype { case streamTypeRequest:
diff --git a/internal/http3/main_test.go b/internal/http3/main_test.go new file mode 100644 index 0000000..473d618 --- /dev/null +++ b/internal/http3/main_test.go
@@ -0,0 +1,91 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http3_test + +import ( + "fmt" + "os" + "runtime" + "slices" + "strings" + "testing" + "time" +) + +func TestMain(m *testing.M) { + v := m.Run() + if v == 0 && goroutineLeaked() { + os.Exit(1) + } + os.Exit(v) +} + +func runningBenchmarks() bool { + for i, arg := range os.Args { + if strings.HasPrefix(arg, "-test.bench=") && !strings.HasSuffix(arg, "=") { + return true + } + if arg == "-test.bench" && i < len(os.Args)-1 && os.Args[i+1] != "" { + return true + } + } + return false +} + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for g := range strings.SplitSeq(string(buf), "\n\n") { + _, stack, _ := strings.Cut(g, "\n") + stack = strings.TrimSpace(stack) + if stack == "" || + strings.Contains(stack, "testing.(*M).before.func1") || + strings.Contains(stack, "os/signal.signal_recv") || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + slices.Sort(gs) + return +} + +// Verify the other tests didn't leave any goroutines running. +func goroutineLeaked() bool { + if testing.Short() || runningBenchmarks() { + // Don't worry about goroutine leaks in -short mode or in + // benchmark mode. Too distracting when there are false positives. 
+ return false + } + + var stackCount map[string]int + for range 5 { + n := 0 + stackCount = make(map[string]int) + gs := interestingGoroutines() + for _, g := range gs { + stackCount[g]++ + n++ + } + if n == 0 { + return false + } + // Wait for goroutines to schedule and die off: + time.Sleep(100 * time.Millisecond) + } + fmt.Fprintf(os.Stderr, "Too many goroutines running after net/http test(s).\n") + for stack, count := range stackCount { + fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack) + } + return true +}
diff --git a/internal/http3/nethttp_test.go b/internal/http3/nethttp_test.go index 8fc0c65..5bfd0af 100644 --- a/internal/http3/nethttp_test.go +++ b/internal/http3/nethttp_test.go
@@ -7,6 +7,7 @@ package http3_test import ( + "context" "crypto/tls" "io" "net/http" @@ -76,23 +77,44 @@ client := &http.Client{ Transport: tr, - Timeout: time.Second, + // Be extra generous with the timeout, to account for smaller builders + // that we use for e.g. plan9. + Timeout: 5 * time.Second, } <-listenAddrSet - req, err := http.NewRequest("GET", "https://"+listenAddr, nil) - if err != nil { + + for range 5 { + req, err := http.NewRequest("GET", "https://"+listenAddr, nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(b, body) { + t.Errorf("got %v, want %v", string(b), string(body)) + } + // TestMain checks that there are no leaked goroutines after tests have + // finished running. + // Over here, we verify that closing the idle connections of a net/http + // Transport will result in HTTP/3 transport closing any UDP sockets + // after there are no longer any open connections. + // We do this in a loop to verify that CloseIdleConnections will not + // prevent transport from creating a new connection should a new dial + // be started. + tr.CloseIdleConnections() + } + // Similarly when a net/http Server shuts down, the HTTP/3 server should + // also follow. + ctx, cancel := context.WithTimeout(t.Context(), 25*time.Millisecond) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { t.Fatal(err) } - resp, err := client.Do(req) - if err != nil { - t.Errorf("got %v err, want nil", err) - } - defer resp.Body.Close() - b, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - if !slices.Equal(b, body) { - t.Errorf("got %v, want %v", string(b), string(body)) - } }
diff --git a/internal/http3/quic_test.go b/internal/http3/quic_test.go index 52455a8..4e02c40 100644 --- a/internal/http3/quic_test.go +++ b/internal/http3/quic_test.go
@@ -10,7 +10,6 @@ "crypto/tls" "net" "net/netip" - "runtime" "sync" "testing" "time" @@ -23,10 +22,6 @@ // newLocalQUICEndpoint returns a QUIC Endpoint listening on localhost. func newLocalQUICEndpoint(t *testing.T) *quic.Endpoint { t.Helper() - switch runtime.GOOS { - case "plan9": - t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS) - } conf := &quic.Config{ TLSConfig: testTLSConfig, }
diff --git a/internal/http3/server.go b/internal/http3/server.go index c0e6ba4..28c8cda 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go
@@ -7,6 +7,7 @@ import ( "context" "crypto/tls" + "fmt" "io" "maps" "net/http" @@ -33,7 +34,15 @@ initOnce sync.Once - serveCtx context.Context + serveCtx context.Context + serveCtxCancel context.CancelFunc + + // connClosed is used to signal that a connection has been unregistered + // from activeConns. That way, when shutting down gracefully, the server + // can avoid busy-waiting for activeConns to be empty. + connClosed chan any + mu sync.Mutex // Guards fields below. + activeConns map[*serverConn]struct{} } // netHTTPHandler is an interface that is implemented by @@ -51,6 +60,7 @@ BaseContext() context.Context Addr() string ListenErrHook(err error) + ShutdownContext() context.Context } type ServerOpts struct { @@ -91,6 +101,10 @@ handler: stdHandler, serveCtx: stdHandler.BaseContext(), } + s3.init() + s.RegisterOnShutdown(func() { + s3.shutdown(stdHandler.ShutdownContext()) + }) stdHandler.ListenErrHook(s3.listenAndServe(stdHandler.Addr())) } } @@ -109,6 +123,9 @@ return quic.Listen("udp", addr, config) } } + s.serveCtx, s.serveCtxCancel = context.WithCancel(s.serveCtx) + s.activeConns = make(map[*serverConn]struct{}) + s.connClosed = make(chan any, 1) }) } @@ -128,13 +145,80 @@ // and handles requests from those connections. func (s *server) serve(e *quic.Endpoint) error { s.init() - defer e.Close(s.serveCtx) + defer e.Close(canceledCtx) for { qconn, err := e.Accept(s.serveCtx) if err != nil { return err } - go newServerConn(qconn, s.handler) + go s.newServerConn(qconn, s.handler) + } +} + +// shutdown attempts a graceful shutdown for the server. +func (s *server) shutdown(ctx context.Context) { + // Set a reasonable default in case ctx is nil. + if ctx == nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + } + + // Send GOAWAY frames to all active connections to give a chance for them + // to gracefully terminate. 
+ s.mu.Lock() + for sc := range s.activeConns { + // TODO: Modify x/net/quic stream API so that write errors from context + // deadline are sticky. + go sc.sendGoaway() + } + s.mu.Unlock() + + // Complete shutdown as soon as there are no more active connections or ctx + // is done, whichever comes first. + defer func() { + s.mu.Lock() + defer s.mu.Unlock() + s.serveCtxCancel() + for sc := range s.activeConns { + sc.abort(&connectionError{ + code: errH3NoError, + message: "server is shutting down", + }) + } + }() + noMoreConns := func() bool { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.activeConns) == 0 + } + for { + if noMoreConns() { + return + } + select { + case <-ctx.Done(): + return + case <-s.connClosed: + } + } +} + +func (s *server) registerConn(sc *serverConn) { + s.mu.Lock() + defer s.mu.Unlock() + s.activeConns[sc] = struct{}{} +} + +func (s *server) unregisterConn(sc *serverConn) { + s.mu.Lock() + delete(s.activeConns, sc) + s.mu.Unlock() + select { + case s.connClosed <- struct{}{}: + default: + // Channel already full. No need to send more values since we are just + // using this channel as a simpler sync.Cond. } } @@ -145,23 +229,32 @@ enc qpackEncoder dec qpackDecoder handler http.Handler + + // For handling shutdown. + controlStream *stream + mu sync.Mutex // Guards everything below. + maxRequestStreamID int64 + goawaySent bool } -func newServerConn(qconn *quic.Conn, handler http.Handler) { +func (s *server) newServerConn(qconn *quic.Conn, handler http.Handler) { sc := &serverConn{ qconn: qconn, handler: handler, } + s.registerConn(sc) + defer s.unregisterConn(sc) sc.enc.init() // Create control stream and send SETTINGS frame. // TODO: Time out on creating stream. 
- controlStream, err := newConnStream(context.Background(), sc.qconn, streamTypeControl) + var err error + sc.controlStream, err = newConnStream(context.Background(), sc.qconn, streamTypeControl) if err != nil { return } - controlStream.writeSettings() - controlStream.Flush() + sc.controlStream.writeSettings() + sc.controlStream.Flush() sc.acceptStreams(sc.qconn, sc) } @@ -272,7 +365,43 @@ return header, pHeader, nil } +func (sc *serverConn) sendGoaway() { + sc.mu.Lock() + if sc.goawaySent || sc.controlStream == nil { + sc.mu.Unlock() + return + } + sc.goawaySent = true + sc.mu.Unlock() + + // No lock in this section in case writing to stream blocks. This is safe + // since sc.maxRequestStreamID is only updated when sc.goawaySent is false. + sc.controlStream.writeVarint(int64(frameTypeGoaway)) + sc.controlStream.writeVarint(int64(sizeVarint(uint64(sc.maxRequestStreamID)))) + sc.controlStream.writeVarint(sc.maxRequestStreamID) + sc.controlStream.Flush() +} + +// requestShouldGoaway reports whether st's stream ID is >= the ID sent in a +// GOAWAY frame. Before any GOAWAY, it records the highest ID seen and returns false. +func (sc *serverConn) requestShouldGoaway(st *stream) bool { + sc.mu.Lock() + defer sc.mu.Unlock() + if sc.goawaySent { + return st.stream.ID() >= sc.maxRequestStreamID + } else { + sc.maxRequestStreamID = max(sc.maxRequestStreamID, st.stream.ID()) + return false + } +} + func (sc *serverConn) handleRequestStream(st *stream) error { + if sc.requestShouldGoaway(st) { + return &streamError{ + code: errH3RequestRejected, + message: "GOAWAY request with equal or lower ID than the stream has been sent", + } + } header, pHeader, err := sc.parseHeader(st) if err != nil { return err @@ -485,6 +614,23 @@ return status >= 100 && status < 200 } +// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode. +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits.
+ // In the future we might block things over 599 (600 and above aren't defined + // at http://httpwg.org/specs/rfc7231.html#status.codes). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/3, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + func (rw *responseWriter) WriteHeader(statusCode int) { // TODO: handle sending informational status headers (e.g. 103). rw.mu.Lock() @@ -492,6 +638,7 @@ if rw.statusCodeSet { return } + checkWriteHeaderCode(statusCode) // Informational headers can be sent multiple times, and should be flushed // immediately. @@ -579,12 +726,12 @@ // been called before. rw.WriteHeader(http.StatusOK) rw.mu.Lock() + defer rw.mu.Unlock() rw.writeHeaderLockedOnce() if !rw.cannotHaveBody { rw.bw.Write(rw.bb) rw.bb.discard() } - rw.mu.Unlock() rw.st.Flush() }
diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go index 91887a9..2ed6969 100644 --- a/internal/http3/server_test.go +++ b/internal/http3/server_test.go
@@ -6,6 +6,7 @@ import ( "errors" + "fmt" "io" "maps" "net/http" @@ -84,7 +85,6 @@ reqStream.writeHeaders(requestHeader(http.Header{ "header-from-client": {"that", "should", "be", "echoed"}, })) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{ ":status": {"204"}, "Header-From-Client": {"that", "should", "be", "echoed"}, @@ -131,13 +131,11 @@ ":scheme": {"https"}, ":path": {"/some/path?query=value&query2=value2#fragment"}, }) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"321"}}) reqStream.wantClosed("request is complete") reqStream = tc.newStream(streamTypeRequest) reqStream.writeHeaders(http.Header{}) // Missing pseudo-header. - synctest.Wait() reqStream.wantError(quic.StreamErrorCode(errH3MessageError)) }) } @@ -157,7 +155,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{ ":status": {"200"}, "Valid-Name": {"valid value"}, @@ -167,6 +164,47 @@ }) } +func TestServerInvalidStatus(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + gotpanic := make(chan bool) + ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(gotpanic) + defer func() { + if e := recover(); e != nil { + got := fmt.Sprintf("%T, %v", e, e) + want := "string, invalid WriteHeader code 0" + if got != want { + t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) + } + gotpanic <- true + // Set an explicit 503. This also tests that the + // WriteHeader call panics before it recorded that an + // explicit value was set. + w.WriteHeader(503) + + // Verify that writing invalid status will not panic if a + // status is already set anyways. + w.WriteHeader(0) + } + }() + w.WriteHeader(0) // Invalid. Will panic. 
+ })) + tc := ts.connect() + tc.greet() + + reqStream := tc.newStream(streamTypeRequest) + reqStream.writeHeaders(requestHeader(nil)) + if !<-gotpanic { + t.Error("expected panic in handler") + } + synctest.Wait() + reqStream.wantSomeHeaders(http.Header{ + ":status": {"503"}, + }) + reqStream.wantClosed("request is complete") + }) +} + func TestServerBody(t *testing.T) { synctest.Test(t, func(t *testing.T) { ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -185,7 +223,6 @@ bodyContent := []byte("some body content that should be echoed") reqStream.writeData(bodyContent) reqStream.stream.stream.CloseWrite() - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) // Small multiple calls to Write will be coalesced into one DATA frame. reqStream.wantData(append([]byte("/"), bodyContent...)) @@ -204,19 +241,43 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(bodyContent) reqStream.wantClosed("request is complete") reqStream = tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(http.Header{":method": {http.MethodHead}})) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantClosed("request is complete") }) } +func TestServerShutdownGoaway(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ts := newTestServer(t, nil) + + tc := ts.connect() + tc.greet() + tc.wantNotClosed("after initial connection handshake") + + requestCount := int64(5) + for range requestCount { + tc.newStream(streamTypeRequest).writeHeaders(requestHeader(nil)) + } + + control := tc.wantStream(streamTypeControl) + control.wantSettings(nil) + + shutdownComplete := make(chan any) + go func() { + ts.s.shutdown(t.Context()) + shutdownComplete <- struct{}{} + }() + control.wantGoaway((requestCount - 1) * 4) // Request stream ID goes from 0, 4, 8, ... 
+ <-shutdownComplete + }) +} + func TestServerHandlerEmpty(t *testing.T) { synctest.Test(t, func(t *testing.T) { ts := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -227,7 +288,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantClosed("request is complete") }) @@ -251,8 +311,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() - respBody := make([]byte, 100) time.Sleep(time.Second) @@ -293,7 +351,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) for _, data := range []string{"a", "bunch", "of", "things", "to", "stream"} { @@ -366,7 +423,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantHeaders(nil) reqStream.wantData(slices.Repeat([]byte("a"), wantWrittenLen)) reqStream.wantClosed("request is complete") @@ -406,14 +462,12 @@ streamIdle <- true // Wait until server responds with HTTP status 100 before sending the // body. - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"100"}}) body := []byte("body that will be echoed back if we get status 100") reqStream.writeData(body) reqStream.stream.stream.CloseWrite() // Receive the server's response after sending the body. - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(body) reqStream.wantClosed("request is complete") @@ -437,7 +491,6 @@ })) // Server rejects it. - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"403"}}) reqStream.wantData(rejectBody) reqStream.wantClosed("request is complete") @@ -469,7 +522,6 @@ reqStream.stream.stream.CloseWrite() // Verify that no HTTP 100 was sent. 
- synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantClosed("request is complete") }) @@ -492,7 +544,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) reqStream.stream.stream.CloseWrite() - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(serverBody) reqStream.wantClosed("request is complete") @@ -503,7 +554,6 @@ reqStream.writeHeaders(requestHeader(http.Header{ "Content-Length": {"0"}, })) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"200"}}) reqStream.wantData(serverBody) reqStream.wantClosed("request is complete") @@ -546,7 +596,6 @@ "Client-Trailer-B": {"valueb"}, "Undeclared-Trailer": {"undeclared"}, // Undeclared trailer should be ignored. }) - synctest.Wait() reqStream.wantHeaders(nil) reqStream.wantClosed("request is complete") }) @@ -587,7 +636,6 @@ "Client-Trailer-B": {"valueb"}, "Undeclared-Trailer": {"undeclared"}, // Undeclared trailer should be ignored. 
}) - synctest.Wait() reqStream.wantHeaders(nil) reqStream.wantClosed("request is complete") }) @@ -611,7 +659,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{ ":status": {"200"}, "Trailer": {"Server-Trailer-A, Server-Trailer-B, Server-Trailer-C"}, @@ -642,7 +689,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{ ":status": {"200"}, "Trailer": {"Server-Trailer-A, Server-Trailer-B, Server-Trailer-C"}, @@ -778,12 +824,10 @@ })) reqStream.wantIdle("stream is idle until server sends an HTTP 100 status") streamIdle <- true - synctest.Wait() reqStream.wantHeaders(http.Header{":status": {"100"}}) } reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() tt.want.Add(":status", strconv.Itoa(tt.responseStatus)) reqStream.wantHeaders(tt.want) if responseCanHaveBody(tt.responseStatus) { @@ -842,7 +886,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantHeaders(nil) switch { case tt.writeSize > defaultBodyBufferCap: @@ -888,7 +931,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantHeaders(http.Header{ ":status": {"103"}, "Link": { @@ -926,7 +968,6 @@ reqStream := tc.newStream(streamTypeRequest) reqStream.writeHeaders(requestHeader(nil)) - synctest.Wait() reqStream.wantSomeHeaders(http.Header{":status": {"304"}}) reqStream.wantClosed("request is complete") })
diff --git a/internal/http3/transport.go b/internal/http3/transport.go index 4ac1aa0..a99824c 100644 --- a/internal/http3/transport.go +++ b/internal/http3/transport.go
@@ -22,19 +22,17 @@ // TODO: Provide a way to register an HTTP/3 transport with a net/http.transport's // connection pool. type transport struct { - // endpoint is the QUIC endpoint used by connections created by the transport. - // If unset, it is initialized by the first call to Dial. - endpoint *quic.Endpoint - // config is the QUIC configuration used for client connections. - // The config may be nil. - // - // Dial may clone and modify the config. - // The config must not be modified after calling Dial. config *quic.Config - initOnce sync.Once - initErr error + mu sync.Mutex // Guards fields below. + // endpoint is the QUIC endpoint used by connections created by the + // transport. CloseIdleConnections closes and unsets it when activeConns is + // empty and no dials are in flight. If unset, endpoint will be initialized + // by the next call to dial. + endpoint *quic.Endpoint + activeConns map[*clientConn]struct{} + inFlightDials int } // netHTTPTransport implements the net/http.dialClientConner interface, @@ -60,33 +58,65 @@ // TODO: most likely, add another arg for transport configuration. func RegisterTransport(tr *http.Transport) { tr3 := &transport{ - config: &quic.Config{ - TLSConfig: tr.TLSClientConfig.Clone(), - }, + // initConfig will clone the tr.TLSClientConfig.
+ config: initConfig(&quic.Config{ + TLSConfig: tr.TLSClientConfig, + }), + activeConns: make(map[*clientConn]struct{}), } tr.RegisterProtocol("http/3", netHTTPTransport{tr3}) } -func (tr *transport) init() error { - tr.initOnce.Do(func() { - tr.config = initConfig(tr.config) - if tr.endpoint == nil { - tr.endpoint, tr.initErr = quic.Listen("udp", ":0", nil) - } - }) - return tr.initErr +func (tr *transport) incInFlightDials() { + tr.mu.Lock() + defer tr.mu.Unlock() + tr.inFlightDials++ +} + +func (tr *transport) decInFlightDials() { + tr.mu.Lock() + defer tr.mu.Unlock() + tr.inFlightDials-- +} + +func (tr *transport) initEndpoint() (err error) { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.endpoint == nil { + tr.endpoint, err = quic.Listen("udp", ":0", nil) + } + return err } // dial creates a new HTTP/3 client connection. func (tr *transport) dial(ctx context.Context, target string) (*clientConn, error) { - if err := tr.init(); err != nil { + tr.incInFlightDials() + defer tr.decInFlightDials() + + if err := tr.initEndpoint(); err != nil { return nil, err } qconn, err := tr.endpoint.Dial(ctx, "udp", target, tr.config) if err != nil { return nil, err } - return newClientConn(ctx, qconn) + return tr.newClientConn(ctx, qconn) +} + +// CloseIdleConnections is called by net/http.Transport.CloseIdleConnections +// after all existing idle connections are closed using http3.clientConn.Close. +// +// When the transport has no active connections anymore, calling this method +// will make the transport clean up any shared resources that are no longer +// required, such as its QUIC endpoint. +func (tr *transport) CloseIdleConnections() { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.endpoint == nil || len(tr.activeConns) > 0 || tr.inFlightDials > 0 { + return + } + tr.endpoint.Close(canceledCtx) + tr.endpoint = nil } // A clientConn is a client HTTP/3 connection. 
@@ -100,35 +130,69 @@ dec qpackDecoder } -func newClientConn(ctx context.Context, qconn *quic.Conn) (*clientConn, error) { +func (tr *transport) registerConn(cc *clientConn) { + tr.mu.Lock() + defer tr.mu.Unlock() + tr.activeConns[cc] = struct{}{} +} + +func (tr *transport) unregisterConn(cc *clientConn) { + tr.mu.Lock() + defer tr.mu.Unlock() + delete(tr.activeConns, cc) +} + +func (tr *transport) newClientConn(ctx context.Context, qconn *quic.Conn) (*clientConn, error) { cc := &clientConn{ qconn: qconn, } + tr.registerConn(cc) cc.enc.init() // Create control stream and send SETTINGS frame. controlStream, err := newConnStream(ctx, cc.qconn, streamTypeControl) if err != nil { + tr.unregisterConn(cc) return nil, fmt.Errorf("http3: cannot create control stream: %v", err) } controlStream.writeSettings() controlStream.Flush() - go cc.acceptStreams(qconn, cc) + go func() { + cc.acceptStreams(qconn, cc) + tr.unregisterConn(cc) + }() return cc, nil } -// close closes the connection. -// Any in-flight requests are canceled. -// close does not wait for the peer to acknowledge the connection closing. -func (cc *clientConn) close() error { - // Close the QUIC connection immediately with a status of NO_ERROR. - cc.qconn.Abort(nil) +// TODO: implement the rest of net/http.ClientConn methods beyond Close. +func (cc *clientConn) Close() error { + // We need to use Close rather than Abort on the QUIC connection. + // Otherwise, when a net/http.Transport.CloseIdleConnections is called, it + // might call the http3.transport.CloseIdleConnections prior to all idle + // connections being fully closed; this would make it unable to close its + // QUIC endpoint, making http3.transport.CloseIdleConnections a no-op + // unintentionally. + return cc.qconn.Close() +} - // Return any existing error from the peer, but don't wait for it. 
- ctx, cancel := context.WithCancel(context.Background()) - cancel() - return cc.qconn.Wait(ctx) +func (cc *clientConn) Err() error { + return nil +} + +func (cc *clientConn) Reserve() error { + return nil +} + +func (cc *clientConn) Release() { +} + +func (cc *clientConn) Available() int { + return 0 +} + +func (cc *clientConn) InFlight() int { + return 0 } func (cc *clientConn) handleControlStream(st *stream) error {
diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go index 7cf44cd..6fc424b 100644 --- a/internal/http3/transport_test.go +++ b/internal/http3/transport_test.go
@@ -98,7 +98,7 @@ return newTestQUICStream(tq.t, st) } -// wantNotClosed asserts that the peer has not closed the connectioln. +// wantNotClosed asserts that the peer has not closed the connection. func (tq *testQUICConn) wantNotClosed(reason string) { t := tq.t t.Helper() @@ -160,9 +160,7 @@ ts.t.Helper() synctest.Wait() qs := ts.stream.stream - ctx, cancel := context.WithCancel(context.Background()) - cancel() - qs.SetReadContext(ctx) + qs.SetReadContext(canceledCtx) if _, err := qs.Read(make([]byte, 1)); !errors.Is(err, context.Canceled) { ts.t.Fatalf("%v: want stream to be idle, but stream has content", reason) } @@ -186,6 +184,7 @@ // If want is nil, the contents of the frame are ignored. func (ts *testQUICStream) wantHeaders(want http.Header) { ts.t.Helper() + synctest.Wait() ftype, err := ts.readFrameHeader() if err != nil { ts.t.Fatalf("want HEADERS frame, got error: %v", err) @@ -221,6 +220,7 @@ // in want are ignored. func (ts *testQUICStream) wantSomeHeaders(want http.Header) { ts.t.Helper() + synctest.Wait() ftype, err := ts.readFrameHeader() if err != nil { ts.t.Fatalf("want HEADERS frame, got error: %v", err) @@ -334,6 +334,36 @@ } } +func (ts *testQUICStream) wantSettings(f func(settingType, value int64) error) { + ts.t.Helper() + synctest.Wait() + if f == nil { + f = func(settingType, value int64) error { return nil } + } + if err := ts.readSettings(f); err != nil { + ts.t.Fatalf("f returned an error: %v", err) + } +} + +func (ts *testQUICStream) wantGoaway(wantID int64) { + ts.t.Helper() + synctest.Wait() + ftype, err := ts.readFrameHeader() + if err != nil { + ts.t.Fatalf("want GOAWAY frame, got error: %v", err) + } + if ftype != frameTypeGoaway { + ts.t.Fatalf("want GOAWAY frame, got: %v", ftype) + } + gotID, err := ts.readVarint() + if err != nil { + ts.t.Fatalf("failed reading GOAWAY frame, got error: %v", err) + } + if gotID != wantID { + ts.t.Fatalf("got stream ID %v from GOAWAY frame, want %v stream ID", gotID, wantID) + } +} + func (ts 
*testQUICStream) writePushPromise(pushID int64, h http.Header) { ts.t.Helper() headers := ts.encodeHeaders(h) @@ -385,6 +415,7 @@ config: &quic.Config{ TLSConfig: testTLSConfig, }, + activeConns: make(map[*clientConn]struct{}), } cc, err := tr.dial(t.Context(), e2.LocalAddr().String()) @@ -392,7 +423,7 @@ t.Fatal(err) } t.Cleanup(func() { - cc.close() + cc.Close() }) srvConn, err := e2.Accept(t.Context()) if err != nil { @@ -523,11 +554,3 @@ }() return rt } - -// canceledCtx is a canceled Context. -// Used for performing non-blocking QUIC operations. -var canceledCtx = func() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - return ctx -}()
diff --git a/quic/doc.go b/quic/doc.go index 37b19eb..8d5a78f 100644 --- a/quic/doc.go +++ b/quic/doc.go
@@ -42,4 +42,9 @@ // - Stream send/receive windows are configurable, // but are fixed and do not adapt to available throughput. // - Path MTU discovery is not implemented. +// +// # Security Policy +// +// This package is a work in progress, +// and not yet covered by the Go security policy. package quic
diff --git a/quic/endpoint_test.go b/quic/endpoint_test.go index fe14025..b95f043 100644 --- a/quic/endpoint_test.go +++ b/quic/endpoint_test.go
@@ -11,7 +11,6 @@ "io" "log/slog" "net/netip" - "runtime" "sync" "testing" "testing/synctest" @@ -73,10 +72,6 @@ } func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverConn *Conn) { - switch runtime.GOOS { - case "plan9": - t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS) - } t.Helper() ctx := context.Background() e1 := newLocalEndpoint(t, serverSide, conf1)
diff --git a/quic/errors.go b/quic/errors.go index 25b2f62..1226370 100644 --- a/quic/errors.go +++ b/quic/errors.go
@@ -104,7 +104,7 @@ } // A StreamErrorCode is an application protocol error code (RFC 9000, Section 20.2) -// indicating whay a stream is being closed. +// indicating why a stream is being closed. type StreamErrorCode uint64 func (e StreamErrorCode) Error() string {
diff --git a/quic/retry.go b/quic/retry.go index d70b254..0392ca9 100644 --- a/quic/retry.go +++ b/quic/retry.go
@@ -110,6 +110,9 @@ nonce := append([]byte{}, dstConnID...) nonce = append(nonce, token[:tokenNonceLen]...) ciphertext := token[tokenNonceLen:] + if len(nonce) != rs.aead.NonceSize() { + return nil, false + } plaintext, err := rs.aead.Open(nil, nonce, ciphertext, rs.additionalData(srcConnID, addr)) if err != nil {
diff --git a/quic/retry_test.go b/quic/retry_test.go index d7f6ba9..03edd00 100644 --- a/quic/retry_test.go +++ b/quic/retry_test.go
@@ -218,6 +218,37 @@ errInvalidToken)) } +func TestRetryServerShortDstConnID(t *testing.T) { + synctest.Test(t, testRetryServerShortDstConnID) +} +func testRetryServerShortDstConnID(t *testing.T) { + // Verify that a shorter-than-expected Destination Connection ID does not + // cause a panic due to bad nonce length. https://go.dev/issue/78292. + rt := newRetryServerTest(t) + te := rt.te + te.writeDatagram(&testDatagram{ + packets: []*testPacket{{ + ptype: packetTypeInitial, + num: 1, + version: quicVersion1, + srcConnID: rt.originalSrcConnID, + dstConnID: []byte("short id"), + token: rt.retry.token, + frames: []debugFrame{ + debugFrameCrypto{ + data: rt.initialCrypto, + }, + }, + }}, + paddedSize: 1200, + }) + te.wantDatagram("server closes connection after Initial from wrong address", + initialConnectionCloseDatagram( + []byte("short id"), + rt.originalSrcConnID, + errInvalidToken)) +} + func TestRetryServerIgnoresRetry(t *testing.T) { synctest.Test(t, testRetryServerIgnoresRetry) }
diff --git a/quic/stream.go b/quic/stream.go index 4c63207..383a6c1 100644 --- a/quic/stream.go +++ b/quic/stream.go
@@ -187,6 +187,15 @@ return s } +// ID returns the QUIC stream ID of s. +// +// As specified in RFC 9000, the two least significant bits of a stream ID +// indicate the initiator and directionality of the stream. The upper bits are +// the stream number. +func (s *Stream) ID() int64 { + return int64(s.id) +} + // SetReadContext sets the context used for reads from the stream. // // It is not safe to call SetReadContext concurrently.
diff --git a/quic/udp_darwin.go b/quic/udp_darwin.go index a8677cf..91e8e81 100644 --- a/quic/udp_darwin.go +++ b/quic/udp_darwin.go
@@ -8,8 +8,15 @@ import ( "encoding/binary" + "syscall" +) - "golang.org/x/sys/unix" +// These socket options are available on darwin, but are not in the syscall +// package. Since syscall package is frozen, just define them manually here. +const ( + ip_recvtos = 0x1b + ipv6_recvpktinfo = 0x3d + ipv6_pktinfo = 0x2e ) // See udp.go. @@ -32,7 +39,7 @@ func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { // 32-bit integer. // https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/in_tclass.c#L1062-L1073 - b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 4) + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_TOS, 4) binary.NativeEndian.PutUint32(data, uint32(ecn)) return b }
diff --git a/quic/udp_linux.go b/quic/udp_linux.go index ad0ce9c..08deaf9 100644 --- a/quic/udp_linux.go +++ b/quic/udp_linux.go
@@ -7,7 +7,13 @@ package quic import ( - "golang.org/x/sys/unix" + "syscall" +) + +const ( + ip_recvtos = syscall.IP_RECVTOS + ipv6_recvpktinfo = syscall.IPV6_RECVPKTINFO + ipv6_pktinfo = syscall.IPV6_PKTINFO ) // See udp.go. @@ -27,7 +33,7 @@ } func appendCmsgECNv4(b []byte, ecn ecnBits) []byte { - b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_TOS, 1) + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_TOS, 1) data[0] = byte(ecn) return b }
diff --git a/quic/udp_msg.go b/quic/udp_msg.go index 018e281..1090904 100644 --- a/quic/udp_msg.go +++ b/quic/udp_msg.go
@@ -11,9 +11,8 @@ "net" "net/netip" "sync" + "syscall" "unsafe" - - "golang.org/x/sys/unix" ) // Network interface for platforms using sendmsg/recvmsg with cmsgs. @@ -44,11 +43,11 @@ // // If any of these calls fail, we won't get the requested information. // That's fine, we'll gracefully handle the lack. - unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) - unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ip_recvtos, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_RECVTCLASS, 1) if !localAddr.IsValid() { - unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) - unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1) + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, ipv6_recvpktinfo, 1) } }) @@ -75,10 +74,10 @@ ipv6TclassSize = 4 ) control := make([]byte, 0+ - unix.CmsgSpace(inPktinfoSize)+ - unix.CmsgSpace(in6PktinfoSize)+ - unix.CmsgSpace(ipTOSSize)+ - unix.CmsgSpace(ipv6TclassSize)) + syscall.CmsgSpace(inPktinfoSize)+ + syscall.CmsgSpace(in6PktinfoSize)+ + syscall.CmsgSpace(ipTOSSize)+ + syscall.CmsgSpace(ipv6TclassSize)) for { d := newDatagram() @@ -132,36 +131,35 @@ } func parseControl(d *datagram, control []byte) { - for len(control) > 0 { - hdr, data, remainder, err := unix.ParseOneSocketControlMessage(control) - if err != nil { - return - } - control = remainder - switch hdr.Level { - case unix.IPPROTO_IP: - switch hdr.Type { - case unix.IP_TOS, unix.IP_RECVTOS: + msgs, err := syscall.ParseSocketControlMessage(control) + if err != nil { + return + } + for _, m := range msgs { + switch m.Header.Level { + case syscall.IPPROTO_IP: + switch m.Header.Type { + case syscall.IP_TOS, ip_recvtos: // (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS, // just check for both.) 
- if ecn, ok := parseIPTOS(data); ok { + if ecn, ok := parseIPTOS(m.Data); ok { d.ecn = ecn } - case unix.IP_PKTINFO: - if a, ok := parseInPktinfo(data); ok { + case syscall.IP_PKTINFO: + if a, ok := parseInPktinfo(m.Data); ok { d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) } } - case unix.IPPROTO_IPV6: - switch hdr.Type { - case unix.IPV6_TCLASS: + case syscall.IPPROTO_IPV6: + switch m.Header.Type { + case syscall.IPV6_TCLASS: // 32-bit integer containing the traffic class field. // The low two bits are the ECN field. - if ecn, ok := parseIPv6TCLASS(data); ok { + if ecn, ok := parseIPv6TCLASS(m.Data); ok { d.ecn = ecn } - case unix.IPV6_PKTINFO: - if a, ok := parseIn6Pktinfo(data); ok { + case ipv6_pktinfo: + if a, ok := parseIn6Pktinfo(m.Data); ok { d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port()) } } @@ -179,7 +177,7 @@ } func appendCmsgECNv6(b []byte, ecn ecnBits) []byte { - b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4) + b, data := appendCmsg(b, syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, 4) binary.NativeEndian.PutUint32(data, uint32(ecn)) return b } @@ -206,7 +204,7 @@ // struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr ipi_addr; /* IP Header dst address */ // }; - b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_PKTINFO, 12) + b, data := appendCmsg(b, syscall.IPPROTO_IP, syscall.IP_PKTINFO, 12) ip := src.As4() copy(data[4:], ip[:]) return b @@ -228,7 +226,7 @@ // appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address // for an outbound datagram. func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte { - b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, 20) + b, data := appendCmsg(b, syscall.IPPROTO_IPV6, ipv6_pktinfo, 20) ip := src.As16() copy(data[0:], ip[:]) return b @@ -238,10 +236,10 @@ // It returns the new buffer, and the data section of the cmsg. 
func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) { off := len(b) - b = append(b, make([]byte, unix.CmsgSpace(size))...) - h := (*unix.Cmsghdr)(unsafe.Pointer(&b[off])) + b = append(b, make([]byte, syscall.CmsgSpace(size))...) + h := (*syscall.Cmsghdr)(unsafe.Pointer(&b[off])) h.Level = level h.Type = typ - h.SetLen(unix.CmsgLen(size)) - return b, b[off+unix.CmsgSpace(0):][:size] + h.SetLen(syscall.CmsgLen(size)) + return b, b[off+syscall.CmsgSpace(0):][:size] }
diff --git a/quic/udp_other.go b/quic/udp_other.go index 62a82f7..02e4a5f 100644 --- a/quic/udp_other.go +++ b/quic/udp_other.go
@@ -43,7 +43,7 @@ func (c *netUDPConn) Read(f func(*datagram)) { for { dgram := newDatagram() - n, _, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(dgram.b, nil) + n, peerAddr, err := c.c.ReadFromUDPAddrPort(dgram.b) if err != nil { return }
diff --git a/quic/udp_test.go b/quic/udp_test.go index a92aa15..dbd0601 100644 --- a/quic/udp_test.go +++ b/quic/udp_test.go
@@ -127,8 +127,6 @@ if test.srcNet == "udp6" && test.dstNet == "udp" { t.Skipf("%v: no support for mapping IPv4 address to IPv6", runtime.GOOS) } - case "plan9": - t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS) } if runtime.GOARCH == "wasm" && test.srcNet == "udp6" { t.Skipf("%v: IPv6 tests fail when using wasm fake net", runtime.GOARCH)