blob: 39e35cec55620e3edb9fcd3ad57fb6f0dbb4384a [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http_test
import (
"context"
"io"
"net"
"net/http"
"net/http/httptrace"
"testing"
)
// TestTransportPoolConnReusePriorConnection checks that after a request has
// fully completed, the next request runs on that same connection instead of
// causing a new Dial.
func TestTransportPoolConnReusePriorConnection(t *testing.T) {
	dt := newTransportDialTester(t, http1Mode)

	// Nothing has been dialed yet, so the first request must dial.
	first := dt.roundTrip()
	conn := dt.wantDial()
	conn.finish(nil)
	first.wantDone(conn)
	first.finish()

	// The first request is finished, so the second one is expected to run on
	// conn without any further Dial.
	second := dt.roundTrip()
	second.wantDone(conn)
	second.finish()
}
// TestTransportPoolConnCannotReuseConnectionInUse checks that a connection
// still serving one request is not handed to a second concurrent request;
// the second request must get a freshly dialed connection.
func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
	dt := newTransportDialTester(t, http1Mode)

	// Bring up the first connection and deliberately leave its request
	// unfinished, so the connection stays busy.
	busy := dt.roundTrip()
	busyConn := dt.wantDial()
	busyConn.finish(nil)
	busy.wantDone(busyConn)

	// The only existing connection is busy, so this request needs its own Dial.
	// Both requests are finished by the cleanups that roundTrip registers.
	next := dt.roundTrip()
	nextConn := dt.wantDial()
	nextConn.finish(nil)
	next.wantDone(nextConn)
}
// TestTransportPoolConnConnectionBecomesAvailableDuringDial checks that a
// request whose Dial is still in progress is served by an existing connection
// that becomes available first, and that the still-pending Dial's connection
// is later used by another request.
func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
	dt := newTransportDialTester(t, http1Mode)
	// First request creates a new connection.
	rt1 := dt.roundTrip()
	c1 := dt.wantDial()
	c1.finish(nil)
	rt1.wantDone(c1)
	// Second request is made while the first request is still using its connection,
	// so the Transport starts a second Dial (c2). The first request then finishes,
	// making c1 available while c2 is still blocked, so the second request is
	// expected to run on c1.
	rt2 := dt.roundTrip()
	c2 := dt.wantDial()
	rt1.finish()
	rt2.wantDone(c1)
	// This section is a bit overfitted to the current Transport implementation:
	// A third request starts. We have an in-progress dial that was started by rt2,
	// but this new request (rt3) is going to ignore it and make a dial of its own.
	// rt3 will use the first of these dials that completes.
	rt3 := dt.roundTrip()
	c3 := dt.wantDial()
	c2.finish(nil)
	rt3.wantDone(c2)
	// Release the remaining Dial so its DialContext hook is not left blocked
	// waiting on c3.ready.
	c3.finish(nil)
}
// transportDialTester coordinates a test of the Dials made by a client
// Transport. Each Dial is published on dials and then blocks until the test
// completes it, so the test controls exactly when new connections appear.
type transportDialTester struct {
	t   *testing.T
	cst *clientServerTest

	dials chan *transportDialTesterConn // receives each connection as its Dial starts

	// Sequence counters that give RoundTrips and Dials stable IDs in logs.
	roundTripCount int
	dialCount      int
}
// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
//
// The RoundTrip itself runs in a separate goroutine started by
// transportDialTester.roundTrip. That goroutine writes res, err, and conn
// before closing done, and the test reads those fields only after receiving
// from done.
type transportDialTesterRoundTrip struct {
	t           *testing.T
	roundTripID int                      // distinguishes RoundTrips in logs
	cancel      context.CancelFunc       // cancels the Request context
	reqBody     io.WriteCloser           // write half of the Request.Body
	finished    bool                     // set by the first call to finish
	done        chan struct{}            // closed when RoundTrip returns
	res         *http.Response           // RoundTrip result; read only after done is closed
	err         error                    // RoundTrip error; read only after done is closed
	conn        *transportDialTesterConn // connection reported by the GotConn trace hook
}
// A transportDialTesterConn is a client connection created by the Transport as
// part of a dial test. The dial hook returns the wrapper itself as the
// net.Conn, so the GotConn trace hook can tell which Dial produced the
// connection a request ran on.
type transportDialTesterConn struct {
	t      *testing.T
	connID int        // distinguishes Dials in logs
	ready  chan error // receives nil to let the Dial proceed, or an error to fail it

	net.Conn // the real connection, set only after the Dial succeeds
}
// newTransportDialTester starts a client/server test in the given mode whose
// client Transport routes every Dial through the tester. Each Dial is sent on
// dt.dials and then waits for the test to call finish on it. Only after that
// does the hook open a real connection to the server.
func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
	t.Helper()
	dt := &transportDialTester{
		t:     t,
		dials: make(chan *transportDialTesterConn),
	}

	// The handler sends response headers right away, then holds the request
	// open until the client has sent (and closed) the request body. This keeps
	// each connection busy until the test chooses to finish the RoundTrip.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rc := http.NewResponseController(w)
		// Full duplex lets the handler write headers before consuming the body.
		rc.EnableFullDuplex()
		w.WriteHeader(200)
		rc.Flush()
		io.ReadAll(r.Body)
	})

	// dial announces itself to the test, waits for permission (or an injected
	// error), and only then connects to the server for real.
	dial := func(_ context.Context, network, address string) (net.Conn, error) {
		pending := &transportDialTesterConn{
			t:     t,
			ready: make(chan error),
		}
		dt.dials <- pending
		if dialErr := <-pending.ready; dialErr != nil {
			return nil, dialErr
		}
		nc, err := net.Dial(network, address)
		if err != nil {
			return nil, err
		}
		// Hand the wrapper to the Transport instead of nc so that GotConn
		// traces can map requests back to the Dial that served them.
		pending.Conn = nc
		return pending, nil
	}

	dt.cst = newClientServerTest(t, mode, handler, func(tr *http.Transport) {
		tr.DialContext = dial
	})
	return dt
}
// roundTrip starts a RoundTrip.
// It returns immediately, without waiting for the RoundTrip call to complete.
//
// The request is a POST whose body is the read half of an io.Pipe. The
// handler reads the body to the end, so the request stays in progress until
// the test calls finish, which closes the write half. A cleanup is also
// registered on dt.t. It cancels the request context and finishes the
// RoundTrip when the test ends.
func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
	dt.t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	pr, pw := io.Pipe()
	rt := &transportDialTesterRoundTrip{
		t:           dt.t,
		roundTripID: dt.roundTripCount,
		done:        make(chan struct{}),
		reqBody:     pw,
		cancel:      cancel,
	}
	dt.roundTripCount++
	dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
	dt.t.Cleanup(func() {
		rt.cancel()
		rt.finish()
	})
	go func() {
		traceCtx := httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				// Use the two-value assertion: if the connection is not one
				// of ours, leave rt.conn nil and let wantDone report the
				// mismatch rather than panicking inside this goroutine.
				if c, ok := info.Conn.(*transportDialTesterConn); ok {
					rt.conn = c
				}
			},
		})
		req, err := http.NewRequestWithContext(traceCtx, "POST", dt.cst.ts.URL, pr)
		if err != nil {
			// Surface request-construction failures through rt.err instead
			// of passing a nil request to RoundTrip.
			rt.err = err
		} else {
			req.Header.Set("Content-Type", "text/plain")
			rt.res, rt.err = dt.cst.tr.RoundTrip(req)
		}
		dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
		close(rt.done)
	}()
	return rt
}
// wantDone waits for the RoundTrip to return. It fails the test unless the
// RoundTrip succeeded and ran on connection c.
func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
	rt.t.Helper()
	<-rt.done
	if rt.err != nil {
		rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
	}
	if rt.conn != c {
		// Either side may be nil (e.g. no connection was recorded by the
		// GotConn hook), so format IDs without dereferencing a nil pointer.
		connID := func(tc *transportDialTesterConn) any {
			if tc == nil {
				return "<nil>"
			}
			return tc.connID
		}
		rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, connID(c), connID(rt.conn))
	}
}
// finish ends a RoundTrip. It waits for RoundTrip to return. If the RoundTrip
// succeeded, it then:
//   - closes the request body, so the server handler's read of it can complete;
//   - drains the response body;
//   - closes the response body.
//
// Only the first call does anything; later calls return immediately.
func (rt *transportDialTesterRoundTrip) finish() {
	rt.t.Helper()
	if rt.finished {
		return
	}
	rt.finished = true

	<-rt.done
	if rt.err == nil {
		resBody := rt.res.Body
		rt.reqBody.Close()
		io.Copy(io.Discard, resBody)
		resBody.Close()
		rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
	}
}
// wantDial waits (without a timeout) for the Transport to start a Dial. It
// gives the pending connection the next sequential connID for logging and
// returns it. That Dial stays blocked until the test calls finish on the
// returned connection.
func (dt *transportDialTester) wantDial() *transportDialTesterConn {
	// Mark as a helper, like the other harness methods, so the log line is
	// attributed to the calling test.
	dt.t.Helper()
	c := <-dt.dials
	c.connID = dt.dialCount
	dt.dialCount++
	dt.t.Logf("Dial %v: started", c.connID)
	return c
}
// finish releases the pending Dial for c. A nil err lets the dial hook go on
// and connect to the server; a non-nil err is returned from the dial hook
// as-is. The ready channel is closed afterwards, so finish must be called at
// most once per connection.
func (c *transportDialTesterConn) finish(err error) {
	ch := c.ready
	c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
	ch <- err
	close(ch)
}