| // Copyright 2011 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Tests for transport.go. |
| // |
// More tests are in clientserver_test.go (for things testing both client & server for both
// HTTP/1 and HTTP/2).
| |
| package http_test |
| |
| import ( |
| "bufio" |
| "bytes" |
| "compress/gzip" |
| "context" |
| "crypto/rand" |
| "crypto/tls" |
| "crypto/x509" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "go/token" |
| "internal/nettrace" |
| "io" |
| "log" |
| mrand "math/rand" |
| "net" |
| . "net/http" |
| "net/http/httptest" |
| "net/http/httptrace" |
| "net/http/httputil" |
| "net/http/internal/testcert" |
| "net/textproto" |
| "net/url" |
| "os" |
| "reflect" |
| "runtime" |
| "strconv" |
| "strings" |
| "sync" |
| "sync/atomic" |
| "testing" |
| "testing/iotest" |
| "time" |
| |
| "golang.org/x/net/http/httpguts" |
| ) |
| |
| // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close |
| // and then verify that the final 2 responses get errors back. |
| |
// hostPortHandler responds with the client's "host:port" (the request's
// RemoteAddr) as the body. It always sets X-Saw-Close to whether the request
// asked for its connection to be closed, and when the "close" form value is
// "true" it also answers with "Connection: close".
var hostPortHandler = HandlerFunc(func(rw ResponseWriter, req *Request) {
	hdr := rw.Header()
	if req.FormValue("close") == "true" {
		hdr.Set("Connection", "close")
	}
	hdr.Set("X-Saw-Close", fmt.Sprint(req.Close))
	rw.Write([]byte(req.RemoteAddr))
})
| |
// testCloseConn is a net.Conn tracked by a testConnSet. Closing it records
// the close in its owning set before closing the wrapped connection.
type testCloseConn struct {
	net.Conn
	set *testConnSet
}

// Close marks c as closed in its set, then closes the underlying connection.
func (c *testCloseConn) Close() error {
	c.set.remove(c)
	return c.Conn.Close()
}

// testConnSet tracks a set of TCP connections and whether they've
// been closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex // guards closed and list
	closed map[net.Conn]bool
	list   []net.Conn // in order created
}

// insert registers c as a new, not-yet-closed connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. The entry stays in list so that check can
// still report connections by creation order.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}

// makeTestDial returns a new testConnSet and a dial function (suitable for
// Transport.Dial) that records every connection it creates in that set.
// Some tests use this to manage raw TCP connections for later inspection.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	connSet := &testConnSet{
		t:      t,
		closed: make(map[net.Conn]bool),
	}
	dial := func(n, addr string) (net.Conn, error) {
		c, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		tc := &testCloseConn{c, connSet}
		connSet.insert(tc)
		return tc, nil
	}
	return connSet, dial
}

// check reports an error on t for every connection in tcs that has not been
// closed. Connections may be closed asynchronously by other goroutines, so
// check re-examines the set up to five times, 50ms apart, before reporting.
func (tcs *testConnSet) check(t *testing.T) {
	const (
		maxAttempts = 5
		retryDelay  = 50 * time.Millisecond
	)
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	for attempt := 1; ; attempt++ {
		var open []int // indexes into tcs.list of connections not yet closed
		for i, c := range tcs.list {
			if !tcs.closed[c] {
				open = append(open, i)
			}
		}
		if len(open) == 0 {
			return
		}
		if attempt < maxAttempts {
			// Drop the lock while sleeping so pending Close calls, which
			// need tcs.mu to record themselves, can make progress.
			tcs.mu.Unlock()
			time.Sleep(retryDelay)
			tcs.mu.Lock()
			continue
		}
		for _, i := range open {
			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, tcs.list[i], len(tcs.list))
		}
		return
	}
}
| |
| func TestReuseRequest(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.Write([]byte("{}")) |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| req, _ := NewRequest("GET", ts.URL, nil) |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| err = res.Body.Close() |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| res, err = c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| err = res.Body.Close() |
| if err != nil { |
| t.Fatal(err) |
| } |
| } |
| |
| // Two subsequent requests and verify their response is the same. |
| // The response from the server is our own IP:port |
| func TestTransportKeepAlives(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(hostPortHandler) |
| defer ts.Close() |
| |
| c := ts.Client() |
| for _, disableKeepAlive := range []bool{false, true} { |
| c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive |
| fetch := func(n int) string { |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) |
| } |
| body, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) |
| } |
| return string(body) |
| } |
| |
| body1 := fetch(1) |
| body2 := fetch(2) |
| |
| bodiesDiffer := body1 != body2 |
| if bodiesDiffer != disableKeepAlive { |
| t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", |
| disableKeepAlive, bodiesDiffer, body1, body2) |
| } |
| } |
| } |
| |
| func TestTransportConnectionCloseOnResponse(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(hostPortHandler) |
| defer ts.Close() |
| |
| connSet, testDial := makeTestDial(t) |
| |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| tr.Dial = testDial |
| |
| for _, connectionClose := range []bool{false, true} { |
| fetch := func(n int) string { |
| req := new(Request) |
| var err error |
| req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) |
| if err != nil { |
| t.Fatalf("URL parse error: %v", err) |
| } |
| req.Method = "GET" |
| req.Proto = "HTTP/1.1" |
| req.ProtoMajor = 1 |
| req.ProtoMinor = 1 |
| |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) |
| } |
| defer res.Body.Close() |
| body, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) |
| } |
| return string(body) |
| } |
| |
| body1 := fetch(1) |
| body2 := fetch(2) |
| bodiesDiffer := body1 != body2 |
| if bodiesDiffer != connectionClose { |
| t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", |
| connectionClose, bodiesDiffer, body1, body2) |
| } |
| |
| tr.CloseIdleConnections() |
| } |
| |
| connSet.check(t) |
| } |
| |
| func TestTransportConnectionCloseOnRequest(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(hostPortHandler) |
| defer ts.Close() |
| |
| connSet, testDial := makeTestDial(t) |
| |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| tr.Dial = testDial |
| for _, connectionClose := range []bool{false, true} { |
| fetch := func(n int) string { |
| req := new(Request) |
| var err error |
| req.URL, err = url.Parse(ts.URL) |
| if err != nil { |
| t.Fatalf("URL parse error: %v", err) |
| } |
| req.Method = "GET" |
| req.Proto = "HTTP/1.1" |
| req.ProtoMajor = 1 |
| req.ProtoMinor = 1 |
| req.Close = connectionClose |
| |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) |
| } |
| if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(connectionClose); got != want { |
| t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v", |
| connectionClose, got, !connectionClose) |
| } |
| body, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) |
| } |
| return string(body) |
| } |
| |
| body1 := fetch(1) |
| body2 := fetch(2) |
| bodiesDiffer := body1 != body2 |
| if bodiesDiffer != connectionClose { |
| t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", |
| connectionClose, bodiesDiffer, body1, body2) |
| } |
| |
| tr.CloseIdleConnections() |
| } |
| |
| connSet.check(t) |
| } |
| |
| // if the Transport's DisableKeepAlives is set, all requests should |
| // send Connection: close. |
| // HTTP/1-only (Connection: close doesn't exist in h2) |
| func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(hostPortHandler) |
| defer ts.Close() |
| |
| c := ts.Client() |
| c.Transport.(*Transport).DisableKeepAlives = true |
| |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| if res.Header.Get("X-Saw-Close") != "true" { |
| t.Errorf("handler didn't see Connection: close ") |
| } |
| } |
| |
| // Test that Transport only sends one "Connection: close", regardless of |
| // how "close" was indicated. |
| func TestTransportRespectRequestWantsClose(t *testing.T) { |
| tests := []struct { |
| disableKeepAlives bool |
| close bool |
| }{ |
| {disableKeepAlives: false, close: false}, |
| {disableKeepAlives: false, close: true}, |
| {disableKeepAlives: true, close: false}, |
| {disableKeepAlives: true, close: true}, |
| } |
| |
| for _, tc := range tests { |
| t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), |
| func(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(hostPortHandler) |
| defer ts.Close() |
| |
| c := ts.Client() |
| c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives |
| req, err := NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| count := 0 |
| trace := &httptrace.ClientTrace{ |
| WroteHeaderField: func(key string, field []string) { |
| if key != "Connection" { |
| return |
| } |
| if httpguts.HeaderValuesContainsToken(field, "close") { |
| count += 1 |
| } |
| }, |
| } |
| req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) |
| req.Close = tc.close |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want { |
| t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) |
| } |
| }) |
| } |
| |
| } |
| |
| func TestTransportIdleCacheKeys(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(hostPortHandler) |
| defer ts.Close() |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| |
| if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { |
| t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) |
| } |
| |
| resp, err := c.Get(ts.URL) |
| if err != nil { |
| t.Error(err) |
| } |
| io.ReadAll(resp.Body) |
| |
| keys := tr.IdleConnKeysForTesting() |
| if e, g := 1, len(keys); e != g { |
| t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) |
| } |
| |
| if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { |
| t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) |
| } |
| |
| tr.CloseIdleConnections() |
| if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { |
| t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) |
| } |
| } |
| |
// TestTransportReadToEndReusesConn tests that the HTTP transport re-uses
// connections when a client reads to the end of a response Body without
// closing it. Both a response with an explicit Content-Length and a
// flushed response with no Content-Length ("/chunked/") are covered.
func TestTransportReadToEndReusesConn(t *testing.T) {
	defer afterTest(t)
	const msg = "foobar"

	// addrSeen counts requests per client RemoteAddr as seen by the handler.
	// It is reset for each path; a single distinct address per path means
	// every request on that path reused one connection.
	var addrSeen map[string]int
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		addrSeen[r.RemoteAddr]++
		if r.URL.Path == "/chunked/" {
			// Flush before writing the body, without setting Content-Length;
			// the client is expected to see ContentLength == -1 (see wantLen).
			w.WriteHeader(200)
			w.(Flusher).Flush()
		} else {
			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
			w.WriteHeader(200)
		}
		w.Write([]byte(msg))
	}))
	defer ts.Close()

	buf := make([]byte, len(msg))

	for pi, path := range []string{"/content-length/", "/chunked/"} {
		// Expected res.ContentLength for this path: the real length, or -1
		// (unknown) for the chunked path.
		wantLen := []int{len(msg), -1}[pi]
		addrSeen = make(map[string]int)
		for i := 0; i < 3; i++ {
			// NOTE(review): this uses the package-level Get (default client)
			// rather than ts.Client(); presumably fine for a plain-HTTP server.
			res, err := Get(ts.URL + path)
			if err != nil {
				t.Errorf("Get %s: %v", path, err)
				continue
			}
			// We want to close this body eventually (before the
			// defer afterTest at top runs), but not before the
			// len(addrSeen) check at the bottom of this test,
			// since Closing this early in the loop would risk
			// making connections be re-used for the wrong reason.
			defer res.Body.Close()

			if res.ContentLength != int64(wantLen) {
				t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
			}
			// A single Read into a buffer of exactly len(msg) bytes is
			// expected to return the whole body together with io.EOF.
			n, err := res.Body.Read(buf)
			if n != len(msg) || err != io.EOF {
				t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg))
			}
		}
		if len(addrSeen) != 1 {
			t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
		}
	}
}
| |
// TestTransportMaxPerHostIdleConns verifies that Transport.MaxIdleConnsPerHost
// caps the idle pool. Three requests are held open by the handler and released
// one at a time. After each one completes, the idle count is expected to be
// 1, then 2, and then to stay at maxIdleConnsPerHost (2) after the third.
func TestTransportMaxPerHostIdleConns(t *testing.T) {
	defer afterTest(t)
	stop := make(chan struct{}) // stop marks the exit of main Test goroutine
	defer close(stop)

	// gotReq is signalled by the handler as soon as a request arrives;
	// resch delivers the body the next waiting handler should write.
	resch := make(chan string)
	gotReq := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReq <- true
		var msg string
		// Wait to be told what to write, or give up once the test has exited.
		select {
		case <-stop:
			return
		case msg = <-resch:
		}
		_, err := w.Write([]byte(msg))
		if err != nil {
			t.Errorf("Write: %v", err)
			return
		}
	}))
	defer ts.Close()

	c := ts.Client()
	tr := c.Transport.(*Transport)
	maxIdleConnsPerHost := 2
	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost

	// Start 3 outstanding requests and wait for the server to get them.
	// Their responses will hang until we write to resch, though.
	donech := make(chan bool)
	doReq := func() {
		// Signal completion on donech unless the test has already exited.
		// The value sent (t.Failed()) is not inspected by the receivers below.
		defer func() {
			select {
			case <-stop:
				return
			case donech <- t.Failed():
			}
		}()
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
			return
		}
		// The body is read to EOF but not closed; TestTransportReadToEndReusesConn
		// covers that this lets the connection be reused.
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("ReadAll: %v", err)
			return
		}
	}
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
	}

	// Release one response at a time, waiting for its request to finish
	// before inspecting the idle pool.
	resch <- "res1"
	<-donech
	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
	}
	addr := ts.Listener.Addr().String()
	cacheKey := "|http|" + addr
	if keys[0] != cacheKey {
		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
	}
	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
	}

	resch <- "res2"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
		t.Errorf("after second response, idle conns = %d; want %d", g, w)
	}

	resch <- "res3"
	<-donech
	// A third idle connection would exceed the per-host limit, so the
	// count must not grow past maxIdleConnsPerHost.
	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
		t.Errorf("after third response, idle conns = %d; want %d", g, w)
	}
}
| |
| func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| _, err := w.Write([]byte("foo")) |
| if err != nil { |
| t.Fatalf("Write: %v", err) |
| } |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| dialStarted := make(chan struct{}) |
| stallDial := make(chan struct{}) |
| tr.Dial = func(network, addr string) (net.Conn, error) { |
| dialStarted <- struct{}{} |
| <-stallDial |
| return net.Dial(network, addr) |
| } |
| |
| tr.DisableKeepAlives = true |
| tr.MaxConnsPerHost = 1 |
| |
| preDial := make(chan struct{}) |
| reqComplete := make(chan struct{}) |
| doReq := func(reqId string) { |
| req, _ := NewRequest("GET", ts.URL, nil) |
| trace := &httptrace.ClientTrace{ |
| GetConn: func(hostPort string) { |
| preDial <- struct{}{} |
| }, |
| } |
| req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) |
| resp, err := tr.RoundTrip(req) |
| if err != nil { |
| t.Errorf("unexpected error for request %s: %v", reqId, err) |
| } |
| _, err = io.ReadAll(resp.Body) |
| if err != nil { |
| t.Errorf("unexpected error for request %s: %v", reqId, err) |
| } |
| reqComplete <- struct{}{} |
| } |
| // get req1 to dial-in-progress |
| go doReq("req1") |
| <-preDial |
| <-dialStarted |
| |
| // get req2 to waiting on conns per host to go down below max |
| go doReq("req2") |
| <-preDial |
| select { |
| case <-dialStarted: |
| t.Error("req2 dial started while req1 dial in progress") |
| return |
| default: |
| } |
| |
| // let req1 complete |
| stallDial <- struct{}{} |
| <-reqComplete |
| |
| // let req2 complete |
| <-dialStarted |
| stallDial <- struct{}{} |
| <-reqComplete |
| } |
| |
| func TestTransportMaxConnsPerHost(t *testing.T) { |
| defer afterTest(t) |
| CondSkipHTTP2(t) |
| |
| h := HandlerFunc(func(w ResponseWriter, r *Request) { |
| _, err := w.Write([]byte("foo")) |
| if err != nil { |
| t.Fatalf("Write: %v", err) |
| } |
| }) |
| |
| testMaxConns := func(scheme string, ts *httptest.Server) { |
| defer ts.Close() |
| |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| tr.MaxConnsPerHost = 1 |
| if err := ExportHttp2ConfigureTransport(tr); err != nil { |
| t.Fatalf("ExportHttp2ConfigureTransport: %v", err) |
| } |
| |
| mu := sync.Mutex{} |
| var conns []net.Conn |
| var dialCnt, gotConnCnt, tlsHandshakeCnt int32 |
| tr.Dial = func(network, addr string) (net.Conn, error) { |
| atomic.AddInt32(&dialCnt, 1) |
| c, err := net.Dial(network, addr) |
| mu.Lock() |
| defer mu.Unlock() |
| conns = append(conns, c) |
| return c, err |
| } |
| |
| doReq := func() { |
| trace := &httptrace.ClientTrace{ |
| GotConn: func(connInfo httptrace.GotConnInfo) { |
| if !connInfo.Reused { |
| atomic.AddInt32(&gotConnCnt, 1) |
| } |
| }, |
| TLSHandshakeStart: func() { |
| atomic.AddInt32(&tlsHandshakeCnt, 1) |
| }, |
| } |
| req, _ := NewRequest("GET", ts.URL, nil) |
| req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) |
| |
| resp, err := c.Do(req) |
| if err != nil { |
| t.Fatalf("request failed: %v", err) |
| } |
| defer resp.Body.Close() |
| _, err = io.ReadAll(resp.Body) |
| if err != nil { |
| t.Fatalf("read body failed: %v", err) |
| } |
| } |
| |
| wg := sync.WaitGroup{} |
| for i := 0; i < 10; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| doReq() |
| }() |
| } |
| wg.Wait() |
| |
| expected := int32(tr.MaxConnsPerHost) |
| if dialCnt != expected { |
| t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) |
| } |
| if gotConnCnt != expected { |
| t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) |
| } |
| if ts.TLS != nil && tlsHandshakeCnt != expected { |
| t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) |
| } |
| |
| if t.Failed() { |
| t.FailNow() |
| } |
| |
| mu.Lock() |
| for _, c := range conns { |
| c.Close() |
| } |
| conns = nil |
| mu.Unlock() |
| tr.CloseIdleConnections() |
| |
| doReq() |
| expected++ |
| if dialCnt != expected { |
| t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) |
| } |
| if gotConnCnt != expected { |
| t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) |
| } |
| if ts.TLS != nil && tlsHandshakeCnt != expected { |
| t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) |
| } |
| } |
| |
| testMaxConns("http", httptest.NewServer(h)) |
| testMaxConns("https", httptest.NewTLSServer(h)) |
| |
| ts := httptest.NewUnstartedServer(h) |
| ts.TLS = &tls.Config{NextProtos: []string{"h2"}} |
| ts.StartTLS() |
| testMaxConns("http2", ts) |
| } |
| |
| func TestTransportRemovesDeadIdleConnections(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| io.WriteString(w, r.RemoteAddr) |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| |
| doReq := func(name string) string { |
| // Do a POST instead of a GET to prevent the Transport's |
| // idempotent request retry logic from kicking in... |
| res, err := c.Post(ts.URL, "", nil) |
| if err != nil { |
| t.Fatalf("%s: %v", name, err) |
| } |
| if res.StatusCode != 200 { |
| t.Fatalf("%s: %v", name, res.Status) |
| } |
| defer res.Body.Close() |
| slurp, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatalf("%s: %v", name, err) |
| } |
| return string(slurp) |
| } |
| |
| first := doReq("first") |
| keys1 := tr.IdleConnKeysForTesting() |
| |
| ts.CloseClientConnections() |
| |
| var keys2 []string |
| if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool { |
| keys2 = tr.IdleConnKeysForTesting() |
| return len(keys2) == 0 |
| }) { |
| t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) |
| } |
| |
| second := doReq("second") |
| if first == second { |
| t.Errorf("expected a different connection between requests. got %q both times", first) |
| } |
| } |
| |
| // Test that the Transport notices when a server hangs up on its |
| // unexpectedly (a keep-alive connection is closed). |
| func TestTransportServerClosingUnexpectedly(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| ts := httptest.NewServer(hostPortHandler) |
| defer ts.Close() |
| c := ts.Client() |
| |
| fetch := func(n, retries int) string { |
| condFatalf := func(format string, arg ...any) { |
| if retries <= 0 { |
| t.Fatalf(format, arg...) |
| } |
| t.Logf("retrying shortly after expected error: "+format, arg...) |
| time.Sleep(time.Second / time.Duration(retries)) |
| } |
| for retries >= 0 { |
| retries-- |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| condFatalf("error in req #%d, GET: %v", n, err) |
| continue |
| } |
| body, err := io.ReadAll(res.Body) |
| if err != nil { |
| condFatalf("error in req #%d, ReadAll: %v", n, err) |
| continue |
| } |
| res.Body.Close() |
| return string(body) |
| } |
| panic("unreachable") |
| } |
| |
| body1 := fetch(1, 0) |
| body2 := fetch(2, 0) |
| |
| // Close all the idle connections in a way that's similar to |
| // the server hanging up on us. We don't use |
| // httptest.Server.CloseClientConnections because it's |
| // best-effort and stops blocking after 5 seconds. On a loaded |
| // machine running many tests concurrently it's possible for |
| // that method to be async and cause the body3 fetch below to |
| // run on an old connection. This function is synchronous. |
| ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) |
| |
| body3 := fetch(3, 5) |
| |
| if body1 != body2 { |
| t.Errorf("expected body1 and body2 to be equal") |
| } |
| if body2 == body3 { |
| t.Errorf("expected body2 and body3 to be different") |
| } |
| } |
| |
// TestStressSurpriseServerCloses is a regression test for
// https://golang.org/issue/2616. The handler writes a complete response,
// then hijacks and closes the connection, so clients trying to reuse it
// race against the close. The test only requires that every request
// eventually returns (successfully or with an error) rather than hanging.
// This fails pretty reliably with GOMAXPROCS=100 or something high.
func TestStressSurpriseServerCloses(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in short mode")
	}
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "5")
		w.Header().Set("Content-Type", "text/plain")
		w.Write([]byte("Hello"))
		w.(Flusher).Flush()
		// Take over the connection and close it right after the response.
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Flush()
		conn.Close()
	}))
	defer ts.Close()
	c := ts.Client()

	// Do a bunch of traffic from different goroutines. Send to activityc
	// after each request completes, regardless of whether it failed.
	// If these are too high, OS X exhausts its ephemeral ports
	// and hangs waiting for them to transition TCP states. That's
	// not what we want to test. TODO(bradfitz): use an io.Pipe
	// dialer for this test instead?
	const (
		numClients    = 20
		reqsPerClient = 25
	)
	// Each client goroutine receives one value per completed request; the
	// main goroutine sends them, so a stuck request stalls the sender.
	activityc := make(chan bool)
	for i := 0; i < numClients; i++ {
		go func() {
			for i := 0; i < reqsPerClient; i++ {
				res, err := c.Get(ts.URL)
				if err == nil {
					// We expect errors since the server is
					// hanging up on us after telling us to
					// send more requests, so we don't
					// actually care what the error is.
					// But we want to close the body in cases
					// where we won the race.
					res.Body.Close()
				}
				if !<-activityc { // Receives false when close(activityc) is executed
					return
				}
			}
		}()
	}

	// Make sure all the requests come back, one way or another.
	for i := 0; i < numClients*reqsPerClient; i++ {
		select {
		case activityc <- true:
		case <-time.After(5 * time.Second):
			// Closing activityc lets any remaining client goroutines exit.
			close(activityc)
			t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile")
		}
	}
}
| |
| // TestTransportHeadResponses verifies that we deal with Content-Lengths |
| // with no bodies properly |
| func TestTransportHeadResponses(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if r.Method != "HEAD" { |
| panic("expected HEAD; got " + r.Method) |
| } |
| w.Header().Set("Content-Length", "123") |
| w.WriteHeader(200) |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| |
| for i := 0; i < 2; i++ { |
| res, err := c.Head(ts.URL) |
| if err != nil { |
| t.Errorf("error on loop %d: %v", i, err) |
| continue |
| } |
| if e, g := "123", res.Header.Get("Content-Length"); e != g { |
| t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) |
| } |
| if e, g := int64(123), res.ContentLength; e != g { |
| t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) |
| } |
| if all, err := io.ReadAll(res.Body); err != nil { |
| t.Errorf("loop %d: Body ReadAll: %v", i, err) |
| } else if len(all) != 0 { |
| t.Errorf("Bogus body %q", all) |
| } |
| } |
| } |
| |
// TestTransportHeadChunkedResponse verifies that we ignore chunked
// transfer-encoding on responses to HEAD requests. The server claims
// "Transfer-Encoding: chunked" but sends no body, and both HEAD requests
// are expected to arrive from the same client ip:port (x-client-ipport),
// i.e. the connection is reused.
func TestTransportHeadChunkedResponse(t *testing.T) {
	defer afterTest(t)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
		w.Header().Set("x-client-ipport", r.RemoteAddr)
		w.WriteHeader(200)
	}))
	defer ts.Close()
	c := ts.Client()

	// Ensure that we wait for the readLoop to complete before
	// calling Head again
	// NOTE(review): if a Head call fails before the read loop reaches this
	// hook, the unconditional <-didRead below may block; the err checks
	// only run afterwards.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	res1, err := c.Head(ts.URL)
	<-didRead

	if err != nil {
		t.Fatalf("request 1 error: %v", err)
	}

	res2, err := c.Head(ts.URL)
	<-didRead

	if err != nil {
		t.Fatalf("request 2 error: %v", err)
	}
	if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
	}
}
| |
// roundTripTests lists the Accept-Encoding cases driven by TestRoundTripGzip:
// the header the caller sets (accept), the header the server must observe
// (expectAccept), and whether the body reaches the caller still
// gzip-compressed (compressed).
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No Accept-Encoding: the Transport asks for gzip itself and
	// hands the caller a transparently decompressed body.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other explicit Accept-Encoding passes through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit "gzip" passes through and the caller sees raw gzip bytes.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
| |
| // Test that the modification made to the Request by the RoundTripper is cleaned up |
| func TestRoundTripGzip(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| const responseBody = "test response body" |
| ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { |
| accept := req.Header.Get("Accept-Encoding") |
| if expect := req.FormValue("expect_accept"); accept != expect { |
| t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", |
| req.FormValue("testnum"), accept, expect) |
| } |
| if accept == "gzip" { |
| rw.Header().Set("Content-Encoding", "gzip") |
| gz := gzip.NewWriter(rw) |
| gz.Write([]byte(responseBody)) |
| gz.Close() |
| } else { |
| rw.Header().Set("Content-Encoding", accept) |
| rw.Write([]byte(responseBody)) |
| } |
| })) |
| defer ts.Close() |
| tr := ts.Client().Transport.(*Transport) |
| |
| for i, test := range roundTripTests { |
| // Test basic request (no accept-encoding) |
| req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) |
| if test.accept != "" { |
| req.Header.Set("Accept-Encoding", test.accept) |
| } |
| res, err := tr.RoundTrip(req) |
| if err != nil { |
| t.Errorf("%d. RoundTrip: %v", i, err) |
| continue |
| } |
| var body []byte |
| if test.compressed { |
| var r *gzip.Reader |
| r, err = gzip.NewReader(res.Body) |
| if err != nil { |
| t.Errorf("%d. gzip NewReader: %v", i, err) |
| continue |
| } |
| body, err = io.ReadAll(r) |
| res.Body.Close() |
| } else { |
| body, err = io.ReadAll(res.Body) |
| } |
| if err != nil { |
| t.Errorf("%d. Error: %q", i, err) |
| continue |
| } |
| if g, e := string(body), responseBody; g != e { |
| t.Errorf("%d. body = %q; want %q", i, g, e) |
| } |
| if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { |
| t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) |
| } |
| if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { |
| t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) |
| } |
| } |
| |
| } |
| |
| func TestTransportGzip(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" |
| const nRandBytes = 1024 * 1024 |
| ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { |
| if req.Method == "HEAD" { |
| if g := req.Header.Get("Accept-Encoding"); g != "" { |
| t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) |
| } |
| return |
| } |
| if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { |
| t.Errorf("Accept-Encoding = %q, want %q", g, e) |
| } |
| rw.Header().Set("Content-Encoding", "gzip") |
| |
| var w io.Writer = rw |
| var buf bytes.Buffer |
| if req.FormValue("chunked") == "0" { |
| w = &buf |
| defer io.Copy(rw, &buf) |
| defer func() { |
| rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) |
| }() |
| } |
| gz := gzip.NewWriter(w) |
| gz.Write([]byte(testString)) |
| if req.FormValue("body") == "large" { |
| io.CopyN(gz, rand.Reader, nRandBytes) |
| } |
| gz.Close() |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| |
| for _, chunked := range []string{"1", "0"} { |
| // First fetch something large, but only read some of it. |
| res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) |
| if err != nil { |
| t.Fatalf("large get: %v", err) |
| } |
| buf := make([]byte, len(testString)) |
| n, err := io.ReadFull(res.Body, buf) |
| if err != nil { |
| t.Fatalf("partial read of large response: size=%d, %v", n, err) |
| } |
| if e, g := testString, string(buf); e != g { |
| t.Errorf("partial read got %q, expected %q", g, e) |
| } |
| res.Body.Close() |
| // Read on the body, even though it's closed |
| n, err = res.Body.Read(buf) |
| if n != 0 || err == nil { |
| t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) |
| } |
| |
| // Then something small. |
| res, err = c.Get(ts.URL + "/?chunked=" + chunked) |
| if err != nil { |
| t.Fatal(err) |
| } |
| body, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if g, e := string(body), testString; g != e { |
| t.Fatalf("body = %q; want %q", g, e) |
| } |
| if g, e := res.Header.Get("Content-Encoding"), ""; g != e { |
| t.Fatalf("Content-Encoding = %q; want %q", g, e) |
| } |
| |
| // Read on the body after it's been fully read: |
| n, err = res.Body.Read(buf) |
| if n != 0 || err == nil { |
| t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) |
| } |
| res.Body.Close() |
| n, err = res.Body.Read(buf) |
| if n != 0 || err == nil { |
| t.Errorf("expected Read error after Close; got %d, %v", n, err) |
| } |
| } |
| |
| // And a HEAD request too, because they're always weird. |
| res, err := c.Head(ts.URL) |
| if err != nil { |
| t.Fatalf("Head: %v", err) |
| } |
| if res.StatusCode != 200 { |
| t.Errorf("Head status=%d; want=200", res.StatusCode) |
| } |
| } |
| |
| // If a request has Expect:100-continue header, the request blocks sending body until the first response. |
| // Premature consumption of the request body should not be occurred. |
| func TestTransportExpect100Continue(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| |
| ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { |
| switch req.URL.Path { |
| case "/100": |
| // This endpoint implicitly responds 100 Continue and reads body. |
| if _, err := io.Copy(io.Discard, req.Body); err != nil { |
| t.Error("Failed to read Body", err) |
| } |
| rw.WriteHeader(StatusOK) |
| case "/200": |
| // Go 1.5 adds Connection: close header if the client expect |
| // continue but not entire request body is consumed. |
| rw.WriteHeader(StatusOK) |
| case "/500": |
| rw.WriteHeader(StatusInternalServerError) |
| case "/keepalive": |
| // This hijacked endpoint responds error without Connection:close. |
| _, bufrw, err := rw.(Hijacker).Hijack() |
| if err != nil { |
| log.Fatal(err) |
| } |
| bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") |
| bufrw.WriteString("Content-Length: 0\r\n\r\n") |
| bufrw.Flush() |
| case "/timeout": |
| // This endpoint tries to read body without 100 (Continue) response. |
| // After ExpectContinueTimeout, the reading will be started. |
| conn, bufrw, err := rw.(Hijacker).Hijack() |
| if err != nil { |
| log.Fatal(err) |
| } |
| if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { |
| t.Error("Failed to read Body", err) |
| } |
| bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") |
| bufrw.Flush() |
| conn.Close() |
| } |
| |
| })) |
| defer ts.Close() |
| |
| tests := []struct { |
| path string |
| body []byte |
| sent int |
| status int |
| }{ |
| {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. |
| {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. |
| {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. |
| {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. |
| {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. |
| } |
| |
| c := ts.Client() |
| for i, v := range tests { |
| tr := &Transport{ |
| ExpectContinueTimeout: 2 * time.Second, |
| } |
| defer tr.CloseIdleConnections() |
| c.Transport = tr |
| body := bytes.NewReader(v.body) |
| req, err := NewRequest("PUT", ts.URL+v.path, body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| req.Header.Set("Expect", "100-continue") |
| req.ContentLength = int64(len(v.body)) |
| |
| resp, err := c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| resp.Body.Close() |
| |
| sent := len(v.body) - body.Len() |
| if v.status != resp.StatusCode { |
| t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) |
| } |
| if v.sent != sent { |
| t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) |
| } |
| } |
| } |
| |
| func TestSOCKS5Proxy(t *testing.T) { |
| defer afterTest(t) |
| ch := make(chan string, 1) |
| l := newLocalListener(t) |
| defer l.Close() |
| defer close(ch) |
| proxy := func(t *testing.T) { |
| s, err := l.Accept() |
| if err != nil { |
| t.Errorf("socks5 proxy Accept(): %v", err) |
| return |
| } |
| defer s.Close() |
| var buf [22]byte |
| if _, err := io.ReadFull(s, buf[:3]); err != nil { |
| t.Errorf("socks5 proxy initial read: %v", err) |
| return |
| } |
| if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { |
| t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) |
| return |
| } |
| if _, err := s.Write([]byte{5, 0}); err != nil { |
| t.Errorf("socks5 proxy initial write: %v", err) |
| return |
| } |
| if _, err := io.ReadFull(s, buf[:4]); err != nil { |
| t.Errorf("socks5 proxy second read: %v", err) |
| return |
| } |
| if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { |
| t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) |
| return |
| } |
| var ipLen int |
| switch buf[3] { |
| case 1: |
| ipLen = net.IPv4len |
| case 4: |
| ipLen = net.IPv6len |
| default: |
| t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) |
| return |
| } |
| if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { |
| t.Errorf("socks5 proxy address read: %v", err) |
| return |
| } |
| ip := net.IP(buf[4 : ipLen+4]) |
| port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) |
| copy(buf[:3], []byte{5, 0, 0}) |
| if _, err := s.Write(buf[:ipLen+6]); err != nil { |
| t.Errorf("socks5 proxy connect write: %v", err) |
| return |
| } |
| ch <- fmt.Sprintf("proxy for %s:%d", ip, port) |
| |
| // Implement proxying. |
| targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) |
| targetConn, err := net.Dial("tcp", targetHost) |
| if err != nil { |
| t.Errorf("net.Dial failed") |
| return |
| } |
| go io.Copy(targetConn, s) |
| io.Copy(s, targetConn) // Wait for the client to close the socket. |
| targetConn.Close() |
| } |
| |
| pu, err := url.Parse("socks5://" + l.Addr().String()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| sentinelHeader := "X-Sentinel" |
| sentinelValue := "12345" |
| h := HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.Header().Set(sentinelHeader, sentinelValue) |
| }) |
| for _, useTLS := range []bool{false, true} { |
| t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { |
| var ts *httptest.Server |
| if useTLS { |
| ts = httptest.NewTLSServer(h) |
| } else { |
| ts = httptest.NewServer(h) |
| } |
| go proxy(t) |
| c := ts.Client() |
| c.Transport.(*Transport).Proxy = ProxyURL(pu) |
| r, err := c.Head(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if r.Header.Get(sentinelHeader) != sentinelValue { |
| t.Errorf("Failed to retrieve sentinel value") |
| } |
| var got string |
| select { |
| case got = <-ch: |
| case <-time.After(5 * time.Second): |
| t.Fatal("timeout connecting to socks5 proxy") |
| } |
| ts.Close() |
| tsu, err := url.Parse(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| want := "proxy for " + tsu.Host |
| if got != want { |
| t.Errorf("got %q, want %q", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestTransportProxy(t *testing.T) { |
| defer afterTest(t) |
| testCases := []struct{ httpsSite, httpsProxy bool }{ |
| {false, false}, |
| {false, true}, |
| {true, false}, |
| {true, true}, |
| } |
| for _, testCase := range testCases { |
| httpsSite := testCase.httpsSite |
| httpsProxy := testCase.httpsProxy |
| t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { |
| siteCh := make(chan *Request, 1) |
| h1 := HandlerFunc(func(w ResponseWriter, r *Request) { |
| siteCh <- r |
| }) |
| proxyCh := make(chan *Request, 1) |
| h2 := HandlerFunc(func(w ResponseWriter, r *Request) { |
| proxyCh <- r |
| // Implement an entire CONNECT proxy |
| if r.Method == "CONNECT" { |
| hijacker, ok := w.(Hijacker) |
| if !ok { |
| t.Errorf("hijack not allowed") |
| return |
| } |
| clientConn, _, err := hijacker.Hijack() |
| if err != nil { |
| t.Errorf("hijacking failed") |
| return |
| } |
| res := &Response{ |
| StatusCode: StatusOK, |
| Proto: "HTTP/1.1", |
| ProtoMajor: 1, |
| ProtoMinor: 1, |
| Header: make(Header), |
| } |
| |
| targetConn, err := net.Dial("tcp", r.URL.Host) |
| if err != nil { |
| t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) |
| return |
| } |
| |
| if err := res.Write(clientConn); err != nil { |
| t.Errorf("Writing 200 OK failed: %v", err) |
| return |
| } |
| |
| go io.Copy(targetConn, clientConn) |
| go func() { |
| io.Copy(clientConn, targetConn) |
| targetConn.Close() |
| }() |
| } |
| }) |
| var ts *httptest.Server |
| if httpsSite { |
| ts = httptest.NewTLSServer(h1) |
| } else { |
| ts = httptest.NewServer(h1) |
| } |
| var proxy *httptest.Server |
| if httpsProxy { |
| proxy = httptest.NewTLSServer(h2) |
| } else { |
| proxy = httptest.NewServer(h2) |
| } |
| |
| pu, err := url.Parse(proxy.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // If neither server is HTTPS or both are, then c may be derived from either. |
| // If only one server is HTTPS, c must be derived from that server in order |
| // to ensure that it is configured to use the fake root CA from testcert.go. |
| c := proxy.Client() |
| if httpsSite { |
| c = ts.Client() |
| } |
| |
| c.Transport.(*Transport).Proxy = ProxyURL(pu) |
| if _, err := c.Head(ts.URL); err != nil { |
| t.Error(err) |
| } |
| var got *Request |
| select { |
| case got = <-proxyCh: |
| case <-time.After(5 * time.Second): |
| t.Fatal("timeout connecting to http proxy") |
| } |
| c.Transport.(*Transport).CloseIdleConnections() |
| ts.Close() |
| proxy.Close() |
| if httpsSite { |
| // First message should be a CONNECT, asking for a socket to the real server, |
| if got.Method != "CONNECT" { |
| t.Errorf("Wrong method for secure proxying: %q", got.Method) |
| } |
| gotHost := got.URL.Host |
| pu, err := url.Parse(ts.URL) |
| if err != nil { |
| t.Fatal("Invalid site URL") |
| } |
| if wantHost := pu.Host; gotHost != wantHost { |
| t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) |
| } |
| |
| // The next message on the channel should be from the site's server. |
| next := <-siteCh |
| if next.Method != "HEAD" { |
| t.Errorf("Wrong method at destination: %s", next.Method) |
| } |
| if nextURL := next.URL.String(); nextURL != "/" { |
| t.Errorf("Wrong URL at destination: %s", nextURL) |
| } |
| } else { |
| if got.Method != "HEAD" { |
| t.Errorf("Wrong method for destination: %q", got.Method) |
| } |
| gotURL := got.URL.String() |
| wantURL := ts.URL + "/" |
| if gotURL != wantURL { |
| t.Errorf("Got URL %q, want %q", gotURL, wantURL) |
| } |
| } |
| }) |
| } |
| } |
| |
| // Issue 28012: verify that the Transport closes its TCP connection to http proxies |
| // when they're slow to reply to HTTPS CONNECT responses. |
| func TestTransportProxyHTTPSConnectLeak(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| |
| ctx, cancel := context.WithCancel(context.Background()) |
| defer cancel() |
| |
| ln := newLocalListener(t) |
| defer ln.Close() |
| listenerDone := make(chan struct{}) |
| go func() { |
| defer close(listenerDone) |
| c, err := ln.Accept() |
| if err != nil { |
| t.Errorf("Accept: %v", err) |
| return |
| } |
| defer c.Close() |
| // Read the CONNECT request |
| br := bufio.NewReader(c) |
| cr, err := ReadRequest(br) |
| if err != nil { |
| t.Errorf("proxy server failed to read CONNECT request") |
| return |
| } |
| if cr.Method != "CONNECT" { |
| t.Errorf("unexpected method %q", cr.Method) |
| return |
| } |
| |
| // Now hang and never write a response; instead, cancel the request and wait |
| // for the client to close. |
| // (Prior to Issue 28012 being fixed, we never closed.) |
| cancel() |
| var buf [1]byte |
| _, err = br.Read(buf[:]) |
| if err != io.EOF { |
| t.Errorf("proxy server Read err = %v; want EOF", err) |
| } |
| return |
| }() |
| |
| c := &Client{ |
| Transport: &Transport{ |
| Proxy: func(*Request) (*url.URL, error) { |
| return url.Parse("http://" + ln.Addr().String()) |
| }, |
| }, |
| } |
| req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| _, err = c.Do(req) |
| if err == nil { |
| t.Errorf("unexpected Get success") |
| } |
| |
| // Wait unconditionally for the listener goroutine to exit: this should never |
| // hang, so if it does we want a full goroutine dump — and that's exactly what |
| // the testing package will give us when the test run times out. |
| <-listenerDone |
| } |
| |
| // Issue 16997: test transport dial preserves typed errors |
| func TestTransportDialPreservesNetOpProxyError(t *testing.T) { |
| defer afterTest(t) |
| |
| var errDial = errors.New("some dial error") |
| |
| tr := &Transport{ |
| Proxy: func(*Request) (*url.URL, error) { |
| return url.Parse("http://proxy.fake.tld/") |
| }, |
| Dial: func(string, string) (net.Conn, error) { |
| return nil, errDial |
| }, |
| } |
| defer tr.CloseIdleConnections() |
| |
| c := &Client{Transport: tr} |
| req, _ := NewRequest("GET", "http://fake.tld", nil) |
| res, err := c.Do(req) |
| if err == nil { |
| res.Body.Close() |
| t.Fatal("wanted a non-nil error") |
| } |
| |
| uerr, ok := err.(*url.Error) |
| if !ok { |
| t.Fatalf("got %T, want *url.Error", err) |
| } |
| oe, ok := uerr.Err.(*net.OpError) |
| if !ok { |
| t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) |
| } |
| want := &net.OpError{ |
| Op: "proxyconnect", |
| Net: "tcp", |
| Err: errDial, // original error, unwrapped. |
| } |
| if !reflect.DeepEqual(oe, want) { |
| t.Errorf("Got error %#v; want %#v", oe, want) |
| } |
| } |
| |
| // Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader. |
| // |
| // (A bug caused dialConn to instead write the per-request Proxy-Authorization |
| // header through to the shared Header instance, introducing a data race.) |
| func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| |
| proxy := httptest.NewTLSServer(NotFoundHandler()) |
| defer proxy.Close() |
| c := proxy.Client() |
| |
| tr := c.Transport.(*Transport) |
| tr.Proxy = func(*Request) (*url.URL, error) { |
| u, _ := url.Parse(proxy.URL) |
| u.User = url.UserPassword("aladdin", "opensesame") |
| return u, nil |
| } |
| h := tr.ProxyConnectHeader |
| if h == nil { |
| h = make(Header) |
| } |
| tr.ProxyConnectHeader = h.Clone() |
| |
| req, err := NewRequest("GET", "https://golang.fake.tld/", nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| _, err = c.Do(req) |
| if err == nil { |
| t.Errorf("unexpected Get success") |
| } |
| |
| if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { |
| t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) |
| } |
| } |
| |
| // TestTransportGzipRecursive sends a gzip quine and checks that the |
| // client gets the same value back. This is more cute than anything, |
| // but checks that we don't recurse forever, and checks that |
| // Content-Encoding is removed. |
| func TestTransportGzipRecursive(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.Header().Set("Content-Encoding", "gzip") |
| w.Write(rgz) |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| body, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !bytes.Equal(body, rgz) { |
| t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", |
| body, rgz) |
| } |
| if g, e := res.Header.Get("Content-Encoding"), ""; g != e { |
| t.Fatalf("Content-Encoding = %q; want %q", g, e) |
| } |
| } |
| |
| // golang.org/issue/7750: request fails when server replies with |
| // a short gzip body |
| func TestTransportGzipShort(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.Header().Set("Content-Encoding", "gzip") |
| w.Write([]byte{0x1f, 0x8b}) |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| _, err = io.ReadAll(res.Body) |
| if err == nil { |
| t.Fatal("Expect an error from reading a body.") |
| } |
| if err != io.ErrUnexpectedEOF { |
| t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) |
| } |
| } |
| |
// waitNumGoroutine polls until runtime.NumGoroutine() is at most nmax. Between
// checks it sleeps 50ms and forces a GC, retrying at most 10 times. It returns
// the last goroutine count observed, which may still exceed nmax if it gave up.
func waitNumGoroutine(nmax int) int {
	count := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if count <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
| |
| // tests that persistent goroutine connections shut down when no longer desired. |
| func TestTransportPersistConnLeak(t *testing.T) { |
| // Not parallel: counts goroutines |
| defer afterTest(t) |
| |
| const numReq = 25 |
| gotReqCh := make(chan bool, numReq) |
| unblockCh := make(chan bool, numReq) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| gotReqCh <- true |
| <-unblockCh |
| w.Header().Set("Content-Length", "0") |
| w.WriteHeader(204) |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| |
| n0 := runtime.NumGoroutine() |
| |
| didReqCh := make(chan bool, numReq) |
| failed := make(chan bool, numReq) |
| for i := 0; i < numReq; i++ { |
| go func() { |
| res, err := c.Get(ts.URL) |
| didReqCh <- true |
| if err != nil { |
| t.Logf("client fetch error: %v", err) |
| failed <- true |
| return |
| } |
| res.Body.Close() |
| }() |
| } |
| |
| // Wait for all goroutines to be stuck in the Handler. |
| for i := 0; i < numReq; i++ { |
| select { |
| case <-gotReqCh: |
| // ok |
| case <-failed: |
| // Not great but not what we are testing: |
| // sometimes an overloaded system will fail to make all the connections. |
| } |
| } |
| |
| nhigh := runtime.NumGoroutine() |
| |
| // Tell all handlers to unblock and reply. |
| close(unblockCh) |
| |
| // Wait for all HTTP clients to be done. |
| for i := 0; i < numReq; i++ { |
| <-didReqCh |
| } |
| |
| tr.CloseIdleConnections() |
| nfinal := waitNumGoroutine(n0 + 5) |
| |
| growth := nfinal - n0 |
| |
| // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. |
| // Previously we were leaking one per numReq. |
| if int(growth) > 5 { |
| t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) |
| t.Error("too many new goroutines") |
| } |
| } |
| |
| // golang.org/issue/4531: Transport leaks goroutines when |
| // request.ContentLength is explicitly short |
| func TestTransportPersistConnLeakShortBody(t *testing.T) { |
| // Not parallel: measures goroutines. |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| |
| n0 := runtime.NumGoroutine() |
| body := []byte("Hello") |
| for i := 0; i < 20; i++ { |
| req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| req.ContentLength = int64(len(body) - 2) // explicitly short |
| _, err = c.Do(req) |
| if err == nil { |
| t.Fatal("Expect an error from writing too long of a body.") |
| } |
| } |
| nhigh := runtime.NumGoroutine() |
| tr.CloseIdleConnections() |
| nfinal := waitNumGoroutine(n0 + 5) |
| |
| growth := nfinal - n0 |
| |
| // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. |
| // Previously we were leaking one per numReq. |
| t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) |
| if int(growth) > 5 { |
| t.Error("too many new goroutines") |
| } |
| } |
| |
// A countedConn wraps a net.Conn dialed by countingDialer. Each wrapper is a
// separate heap object carrying a finalizer. When the GC finds the wrapper
// unreachable, it may run the finalizer, which decrements the dialer's
// mutex-guarded live count (not an atomic counter).
type countedConn struct {
	net.Conn
}

// A countingDialer dials connections and counts how many it has created in
// total and how many are still live, meaning not yet finalized by the GC.
// Both counters are guarded by mu.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials with d.dialer and wraps the result in a countedConn. It
// increments both counters and registers d.decrement as the wrapper's
// finalizer. Dial errors are returned unchanged and nothing is counted.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}

	counted := &countedConn{Conn: conn}

	d.mu.Lock()
	defer d.mu.Unlock()
	d.total++
	d.live++

	runtime.SetFinalizer(counted, d.decrement)
	return counted, nil
}

// decrement is the finalizer for countedConns: one fewer connection is live.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.live--
}

// Read returns the number of connections dialed so far and how many of them
// have not yet been finalized.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.total, d.live
}
| |
| func TestTransportPersistConnLeakNeverIdle(t *testing.T) { |
| defer afterTest(t) |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| // Close every connection so that it cannot be kept alive. |
| conn, _, err := w.(Hijacker).Hijack() |
| if err != nil { |
| t.Errorf("Hijack failed unexpectedly: %v", err) |
| return |
| } |
| conn.Close() |
| })) |
| defer ts.Close() |
| |
| var d countingDialer |
| c := ts.Client() |
| c.Transport.(*Transport).DialContext = d.DialContext |
| |
| body := []byte("Hello") |
| for i := 0; ; i++ { |
| total, live := d.Read() |
| if live < total { |
| break |
| } |
| if i >= 1<<12 { |
| t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) |
| } |
| |
| req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| _, err = c.Do(req) |
| if err == nil { |
| t.Fatal("expected broken connection") |
| } |
| |
| runtime.GC() |
| } |
| } |
| |
// countedContext wraps a context.Context so that contextCounter can attach a
// finalizer to a distinct heap object for every tracked context.
type countedContext struct {
	context.Context
}

// contextCounter counts the contexts returned by Track that the garbage
// collector has not yet finalized. live is guarded by mu.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a new countedContext and increments the live count. It
// registers cc.decrement as the wrapper's finalizer, then returns the wrapper.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	cc.mu.Unlock()
	return wrapped
}

// decrement is the finalizer for countedContexts: one fewer context is live.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read reports how many tracked contexts have not yet been finalized.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
| |
| func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { |
| if runtime.Compiler == "gccgo" { |
| t.Skip("fails with conservative stack GC") |
| } |
| |
| defer afterTest(t) |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| runtime.Gosched() |
| w.WriteHeader(StatusOK) |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| c.Transport.(*Transport).MaxConnsPerHost = 1 |
| |
| ctx := context.Background() |
| body := []byte("Hello") |
| doPosts := func(cc *contextCounter) { |
| var wg sync.WaitGroup |
| for n := 64; n > 0; n-- { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| |
| ctx := cc.Track(ctx) |
| req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) |
| if err != nil { |
| t.Error(err) |
| } |
| |
| _, err = c.Do(req.WithContext(ctx)) |
| if err != nil { |
| t.Errorf("Do failed with error: %v", err) |
| } |
| }() |
| } |
| wg.Wait() |
| } |
| |
| var initialCC contextCounter |
| doPosts(&initialCC) |
| |
| // flushCC exists only to put pressure on the GC to finalize the initialCC |
| // contexts: the flushCC allocations should eventually displace the initialCC |
| // allocations. |
| var flushCC contextCounter |
| for i := 0; ; i++ { |
| live := initialCC.Read() |
| if live == 0 { |
| break |
| } |
| if i >= 100 { |
| t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i) |
| } |
| doPosts(&flushCC) |
| runtime.GC() |
| } |
| } |
| |
| // This used to crash; https://golang.org/issue/3266 |
| func TestTransportIdleConnCrash(t *testing.T) { |
| defer afterTest(t) |
| var tr *Transport |
| |
| unblockCh := make(chan bool, 1) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| <-unblockCh |
| tr.CloseIdleConnections() |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| tr = c.Transport.(*Transport) |
| |
| didreq := make(chan bool) |
| go func() { |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Error(err) |
| } else { |
| res.Body.Close() // returns idle conn |
| } |
| didreq <- true |
| }() |
| unblockCh <- true |
| <-didreq |
| } |
| |
| // Test that the transport doesn't close the TCP connection early, |
| // before the response body has been read. This was a regression |
| // which sadly lacked a triggering test. The large response body made |
| // the old race easier to trigger. |
| func TestIssue3644(t *testing.T) { |
| defer afterTest(t) |
| const numFoos = 5000 |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.Header().Set("Connection", "close") |
| for i := 0; i < numFoos; i++ { |
| w.Write([]byte("foo ")) |
| } |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| bs, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if len(bs) != numFoos*len("foo ") { |
| t.Errorf("unexpected response length") |
| } |
| } |
| |
| // Test that a client receives a server's reply, even if the server doesn't read |
| // the entire request body. |
| func TestIssue3595(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| const deniedMsg = "sorry, denied." |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| Error(w, deniedMsg, StatusUnauthorized) |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) |
| if err != nil { |
| t.Errorf("Post: %v", err) |
| return |
| } |
| got, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatalf("Body ReadAll: %v", err) |
| } |
| if !strings.Contains(string(got), deniedMsg) { |
| t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) |
| } |
| } |
| |
| // From https://golang.org/issue/4454 , |
| // "client fails to handle requests with no body and chunked encoding" |
| func TestChunkedNoContent(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.WriteHeader(StatusNoContent) |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| for _, closeBody := range []bool{true, false} { |
| const n = 4 |
| for i := 1; i <= n; i++ { |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) |
| } else { |
| if closeBody { |
| res.Body.Close() |
| } |
| } |
| } |
| } |
| } |
| |
| func TestTransportConcurrency(t *testing.T) { |
| // Not parallel: uses global test hooks. |
| defer afterTest(t) |
| maxProcs, numReqs := 16, 500 |
| if testing.Short() { |
| maxProcs, numReqs = 4, 50 |
| } |
| defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| fmt.Fprintf(w, "%v", r.FormValue("echo")) |
| })) |
| defer ts.Close() |
| |
| var wg sync.WaitGroup |
| wg.Add(numReqs) |
| |
| // Due to the Transport's "socket late binding" (see |
| // idleConnCh in transport.go), the numReqs HTTP requests |
| // below can finish with a dial still outstanding. To keep |
| // the leak checker happy, keep track of pending dials and |
| // wait for them to finish (and be closed or returned to the |
| // idle pool) before we close idle connections. |
| SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) |
| defer SetPendingDialHooks(nil, nil) |
| |
| c := ts.Client() |
| reqs := make(chan string) |
| defer close(reqs) |
| |
| for i := 0; i < maxProcs*2; i++ { |
| go func() { |
| for req := range reqs { |
| res, err := c.Get(ts.URL + "/?echo=" + req) |
| if err != nil { |
| t.Errorf("error on req %s: %v", req, err) |
| wg.Done() |
| continue |
| } |
| all, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Errorf("read error on req %s: %v", req, err) |
| wg.Done() |
| continue |
| } |
| if string(all) != req { |
| t.Errorf("body of req %s = %q; want %q", req, all, req) |
| } |
| res.Body.Close() |
| wg.Done() |
| } |
| }() |
| } |
| for i := 0; i < numReqs; i++ { |
| reqs <- fmt.Sprintf("request-%d", i) |
| } |
| wg.Wait() |
| } |
| |
| func TestIssue4191_InfiniteGetTimeout(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| const debug = false |
| mux := NewServeMux() |
| mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { |
| io.Copy(w, neverEnding('a')) |
| }) |
| ts := httptest.NewServer(mux) |
| defer ts.Close() |
| timeout := 100 * time.Millisecond |
| |
| c := ts.Client() |
| c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { |
| conn, err := net.Dial(n, addr) |
| if err != nil { |
| return nil, err |
| } |
| conn.SetDeadline(time.Now().Add(timeout)) |
| if debug { |
| conn = NewLoggingConn("client", conn) |
| } |
| return conn, nil |
| } |
| |
| getFailed := false |
| nRuns := 5 |
| if testing.Short() { |
| nRuns = 1 |
| } |
| for i := 0; i < nRuns; i++ { |
| if debug { |
| println("run", i+1, "of", nRuns) |
| } |
| sres, err := c.Get(ts.URL + "/get") |
| if err != nil { |
| if !getFailed { |
| // Make the timeout longer, once. |
| getFailed = true |
| t.Logf("increasing timeout") |
| i-- |
| timeout *= 10 |
| continue |
| } |
| t.Errorf("Error issuing GET: %v", err) |
| break |
| } |
| _, err = io.Copy(io.Discard, sres.Body) |
| if err == nil { |
| t.Errorf("Unexpected successful copy") |
| break |
| } |
| } |
| if debug { |
| println("tests complete; waiting for handlers to finish") |
| } |
| } |
| |
| func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| const debug = false |
| mux := NewServeMux() |
| mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { |
| io.Copy(w, neverEnding('a')) |
| }) |
| mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { |
| defer r.Body.Close() |
| io.Copy(io.Discard, r.Body) |
| }) |
| ts := httptest.NewServer(mux) |
| timeout := 100 * time.Millisecond |
| |
| c := ts.Client() |
| c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { |
| conn, err := net.Dial(n, addr) |
| if err != nil { |
| return nil, err |
| } |
| conn.SetDeadline(time.Now().Add(timeout)) |
| if debug { |
| conn = NewLoggingConn("client", conn) |
| } |
| return conn, nil |
| } |
| |
| getFailed := false |
| nRuns := 5 |
| if testing.Short() { |
| nRuns = 1 |
| } |
| for i := 0; i < nRuns; i++ { |
| if debug { |
| println("run", i+1, "of", nRuns) |
| } |
| sres, err := c.Get(ts.URL + "/get") |
| if err != nil { |
| if !getFailed { |
| // Make the timeout longer, once. |
| getFailed = true |
| t.Logf("increasing timeout") |
| i-- |
| timeout *= 10 |
| continue |
| } |
| t.Errorf("Error issuing GET: %v", err) |
| break |
| } |
| req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) |
| _, err = c.Do(req) |
| if err == nil { |
| sres.Body.Close() |
| t.Errorf("Unexpected successful PUT") |
| break |
| } |
| sres.Body.Close() |
| } |
| if debug { |
| println("tests complete; waiting for handlers to finish") |
| } |
| ts.Close() |
| } |
| |
| func TestTransportResponseHeaderTimeout(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| if testing.Short() { |
| t.Skip("skipping timeout test in -short mode") |
| } |
| inHandler := make(chan bool, 1) |
| mux := NewServeMux() |
| mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { |
| inHandler <- true |
| }) |
| mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { |
| inHandler <- true |
| time.Sleep(2 * time.Second) |
| }) |
| ts := httptest.NewServer(mux) |
| defer ts.Close() |
| |
| c := ts.Client() |
| c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond |
| |
| tests := []struct { |
| path string |
| want int |
| wantErr string |
| }{ |
| {path: "/fast", want: 200}, |
| {path: "/slow", wantErr: "timeout awaiting response headers"}, |
| {path: "/fast", want: 200}, |
| } |
| for i, tt := range tests { |
| req, _ := NewRequest("GET", ts.URL+tt.path, nil) |
| req = req.WithT(t) |
| res, err := c.Do(req) |
| select { |
| case <-inHandler: |
| case <-time.After(5 * time.Second): |
| t.Errorf("never entered handler for test index %d, %s", i, tt.path) |
| continue |
| } |
| if err != nil { |
| uerr, ok := err.(*url.Error) |
| if !ok { |
| t.Errorf("error is not an url.Error; got: %#v", err) |
| continue |
| } |
| nerr, ok := uerr.Err.(net.Error) |
| if !ok { |
| t.Errorf("error does not satisfy net.Error interface; got: %#v", err) |
| continue |
| } |
| if !nerr.Timeout() { |
| t.Errorf("want timeout error; got: %q", nerr) |
| continue |
| } |
| if strings.Contains(err.Error(), tt.wantErr) { |
| continue |
| } |
| t.Errorf("%d. unexpected error: %v", i, err) |
| continue |
| } |
| if tt.wantErr != "" { |
| t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) |
| continue |
| } |
| if res.StatusCode != tt.want { |
| t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) |
| } |
| } |
| } |
| |
// TestTransportCancelRequest checks the following for Transport.CancelRequest
// issued while the response body is still being read:
//   - the body read returns the bytes already sent ("Hello");
//   - the read then fails with ExportErrRequestCanceled;
//   - no pending requests remain once the connection goroutines shut down.
func TestTransportCancelRequest(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// unblockc keeps the handler from finishing its response. It is closed by
	// the deferred close below, which (LIFO) runs before the deferred ts.Close.
	unblockc := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Hello")
		w.(Flusher).Flush() // send headers and some body
		<-unblockc
	}))
	defer ts.Close()
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel about one second after the headers arrive, while ReadAll below
	// is blocked waiting for the rest of the body.
	go func() {
		time.Sleep(1 * time.Second)
		tr.CancelRequest(req)
	}()
	t0 := time.Now()
	body, err := io.ReadAll(res.Body)
	d := time.Since(t0)

	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	}
	if string(body) != "Hello" {
		t.Errorf("Body = %q; want Hello", body)
	}
	// ReadAll must have stayed blocked until the ~1s cancellation.
	if d < 500*time.Millisecond {
		t.Errorf("expected ~1 second delay; got %v", d)
	}
	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	// Polls up to 5 times, 100ms apart; only the final attempt reports failure.
	for tries := 5; tries > 0; tries-- {
		n := tr.NumPendingRequestsForTesting()
		if n == 0 {
			break
		}
		time.Sleep(100 * time.Millisecond)
		if tries == 1 {
			t.Errorf("pending requests = %d; want 0", n)
		}
	}
}
| |
// testTransportCancelRequestInDo starts c.Do(req) in a goroutine against a
// handler that blocks until the test ends. It then calls
// Transport.CancelRequest every 100ms for up to 10s, stopping as soon as Do
// returns. body is the request body and may be nil. The test fails if Do
// has not returned by the deadline.
func testTransportCancelRequestInDo(t *testing.T, body io.Reader) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// The handler never responds until unblockc is closed on return.
	unblockc := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	}))
	defer ts.Close()
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// donec is closed when the Do goroutine returns.
	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	go func() {
		defer close(donec)
		c.Do(req)
	}()
	start := time.Now()
	timeout := 10 * time.Second
	// CancelRequest is retried rather than called once, presumably because an
	// early call can land before the Transport has registered the request.
	for time.Since(start) < timeout {
		time.Sleep(100 * time.Millisecond)
		tr.CancelRequest(req)
		select {
		case <-donec:
			return
		default:
		}
	}
	t.Errorf("Do of canceled request has not returned after %v", timeout)
}
| |
// TestTransportCancelRequestInDo cancels a body-less GET while it is blocked
// inside Client.Do.
func TestTransportCancelRequestInDo(t *testing.T) {
	testTransportCancelRequestInDo(t, nil)
}

// TestTransportCancelRequestWithBodyInDo is TestTransportCancelRequestInDo
// with a one-byte request body.
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
	testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0}))
}
| |
// TestTransportCancelRequestInDial checks Transport.CancelRequest issued while
// the request's Dial is still blocked:
//   - Client.Do returns with the "request canceled while waiting for
//     connection" error;
//   - a second CancelRequest call is harmless.
//
// The order of events is recorded in logbuf and compared verbatim at the end.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf bytes.Buffer
	eventLog := log.New(&logbuf, "", 0)

	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial is an unbuffered handshake with Dial. Dial receives true once the
	// test has seen it block, or false (via close) if the test gives up waiting.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	select {
	case inDial <- true:
	case <-time.After(5 * time.Second):
		close(inDial)
		t.Fatal("timeout; never saw blocking dial")
	}

	eventLog.Printf("canceling")
	tr.CancelRequest(req)
	tr.CancelRequest(req) // used to panic on second call

	select {
	case <-gotres:
	case <-time.After(5 * time.Second):
		// NOTE(review): panics instead of t.Fatal, presumably so the crash
		// dumps all goroutine stacks alongside the event log.
		panic("hang. events are: " + logbuf.String())
	}

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
| |
// TestCancelRequestWithChannel is like TestTransportCancelRequest, but it
// cancels by closing the Request.Cancel channel instead of calling
// Transport.CancelRequest. It checks the following:
//   - the body read returns "Hello";
//   - the read then fails with ExportErrRequestCanceled;
//   - no pending requests remain afterwards.
func TestCancelRequestWithChannel(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// unblockc keeps the handler from completing the body; it is closed on
	// return, before (LIFO) the deferred ts.Close.
	unblockc := make(chan bool)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Hello")
		w.(Flusher).Flush() // send headers and some body
		<-unblockc
	}))
	defer ts.Close()
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	ch := make(chan struct{})
	req.Cancel = ch

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Close the cancel channel about one second into the body read.
	go func() {
		time.Sleep(1 * time.Second)
		close(ch)
	}()
	t0 := time.Now()
	body, err := io.ReadAll(res.Body)
	d := time.Since(t0)

	if err != ExportErrRequestCanceled {
		t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
	}
	if string(body) != "Hello" {
		t.Errorf("Body = %q; want Hello", body)
	}
	// ReadAll must have stayed blocked until the ~1s cancellation.
	if d < 500*time.Millisecond {
		t.Errorf("expected ~1 second delay; got %v", d)
	}
	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	// Polls up to 5 times, 100ms apart; only the final attempt reports failure.
	for tries := 5; tries > 0; tries-- {
		n := tr.NumPendingRequestsForTesting()
		if n == 0 {
			break
		}
		time.Sleep(100 * time.Millisecond)
		if tries == 1 {
			t.Errorf("pending requests = %d; want 0", n)
		}
	}
}
| |
// TestCancelRequestWithChannelBeforeDo_Cancel sends a request whose
// Request.Cancel channel is already closed.
func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
	testCancelRequestWithChannelBeforeDo(t, false)
}

// TestCancelRequestWithChannelBeforeDo_Context sends a request whose context
// is already canceled.
func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
	testCancelRequestWithChannelBeforeDo(t, true)
}
| func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { |
| setParallel(t) |
| defer afterTest(t) |
| unblockc := make(chan bool) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| <-unblockc |
| })) |
| defer ts.Close() |
| defer close(unblockc) |
| |
| c := ts.Client() |
| |
| req, _ := NewRequest("GET", ts.URL, nil) |
| if withCtx { |
| ctx, cancel := context.WithCancel(context.Background()) |
| cancel() |
| req = req.WithContext(ctx) |
| } else { |
| ch := make(chan struct{}) |
| req.Cancel = ch |
| close(ch) |
| } |
| |
| _, err := c.Do(req) |
| if ue, ok := err.(*url.Error); ok { |
| err = ue.Err |
| } |
| if withCtx { |
| if err != context.Canceled { |
| t.Errorf("Do error = %v; want %v", err, context.Canceled) |
| } |
| } else { |
| if err == nil || !strings.Contains(err.Error(), "canceled") { |
| t.Errorf("Do error = %v; want cancellation", err) |
| } |
| } |
| } |
| |
| // Issue 11020. The returned error message should be errRequestCanceled |
| func TestTransportCancelBeforeResponseHeaders(t *testing.T) { |
| defer afterTest(t) |
| |
| serverConnCh := make(chan net.Conn, 1) |
| tr := &Transport{ |
| Dial: func(network, addr string) (net.Conn, error) { |
| cc, sc := net.Pipe() |
| serverConnCh <- sc |
| return cc, nil |
| }, |
| } |
| defer tr.CloseIdleConnections() |
| errc := make(chan error, 1) |
| req, _ := NewRequest("GET", "http://example.com/", nil) |
| go func() { |
| _, err := tr.RoundTrip(req) |
| errc <- err |
| }() |
| |
| sc := <-serverConnCh |
| verb := make([]byte, 3) |
| if _, err := io.ReadFull(sc, verb); err != nil { |
| t.Errorf("Error reading HTTP verb from server: %v", err) |
| } |
| if string(verb) != "GET" { |
| t.Errorf("server received %q; want GET", verb) |
| } |
| defer sc.Close() |
| |
| tr.CancelRequest(req) |
| |
| err := <-errc |
| if err == nil { |
| t.Fatalf("unexpected success from RoundTrip") |
| } |
| if err != ExportErrRequestCanceled { |
| t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err) |
| } |
| } |
| |
// golang.org/issue/3672 -- Client can't close HTTP stream
// Calling Close on a Response.Body used to just read until EOF.
// Now it actually closes the TCP connection.
//
// The handler writes msg in an endless loop. The test reads a few copies,
// requires res.Body.Close to return within 10s, and then requires the
// handler's next Write to fail (reported via writeErr).
func TestTransportCloseResponseBody(t *testing.T) {
	defer afterTest(t)
	writeErr := make(chan error, 1)
	msg := []byte("young\n")
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		for {
			_, err := w.Write(msg)
			if err != nil {
				writeErr <- err
				return
			}
			w.(Flusher).Flush()
		}
	}))
	defer ts.Close()

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	defer tr.CancelRequest(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	const repeats = 3
	buf := make([]byte, len(msg)*repeats)
	want := bytes.Repeat(msg, repeats)

	_, err = io.ReadFull(res.Body, buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, want) {
		t.Fatalf("read %q; want %q", buf, want)
	}
	// Close in a goroutine so a Close that drains the endless body (the old
	// behavior) shows up as a timeout rather than a hang.
	didClose := make(chan error, 1)
	go func() {
		didClose <- res.Body.Close()
	}()
	select {
	case err := <-didClose:
		if err != nil {
			t.Errorf("Close = %v", err)
		}
	case <-time.After(10 * time.Second):
		t.Fatal("too long waiting for close")
	}
	select {
	case err := <-writeErr:
		if err == nil {
			t.Errorf("expected non-nil write error")
		}
	case <-time.After(10 * time.Second):
		t.Fatal("too long waiting for write error")
	}
}
| |
// fooProto is a RoundTripper for the custom "foo" scheme, registered with
// Transport.RegisterProtocol by TestTransportAltProto. It never touches the
// network. Every request gets a 200 response whose body echoes the URL.
type fooProto struct{}

// RoundTrip answers req with "200 OK", an empty header map, and the body
// "You wanted <URL>". It never returns an error.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     Header{},
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
| |
| func TestTransportAltProto(t *testing.T) { |
| defer afterTest(t) |
| tr := &Transport{} |
| c := &Client{Transport: tr} |
| tr.RegisterProtocol("foo", fooProto{}) |
| res, err := c.Get("foo://bar.com/path") |
| if err != nil { |
| t.Fatal(err) |
| } |
| bodyb, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| body := string(bodyb) |
| if e := "You wanted foo://bar.com/path"; body != e { |
| t.Errorf("got response %q, want %q", body, e) |
| } |
| } |
| |
| func TestTransportNoHost(t *testing.T) { |
| defer afterTest(t) |
| tr := &Transport{} |
| _, err := tr.RoundTrip(&Request{ |
| Header: make(Header), |
| URL: &url.URL{ |
| Scheme: "http", |
| }, |
| }) |
| want := "http: no Host in request URL" |
| if got := fmt.Sprint(err); got != want { |
| t.Errorf("error = %v; want %q", err, want) |
| } |
| } |
| |
// Issue 13311: a client request with an empty Method must be sent as GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	// The Request.Method docs say "For client requests an empty string means GET".
	req.Method = ""
	// DumpRequestOut goes through a Transport, so it shows what would be sent.
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Contains(dump, []byte("GET ")) {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
| |
// TestTransportSocketLateBinding checks that a request issued while the only
// connection is busy is served on that same connection once it goes idle.
// Only one Dial may succeed. A second dial, if attempted, blocks on dialGate
// until the final `dialGate <- false` and then fails. /bar must report the
// same client address as /foo.
func TestTransportSocketLateBinding(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	mux := NewServeMux()
	// /foo sends headers immediately but does not finish until fooGate fires.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := httptest.NewServer(mux)
	defer ts.Close()

	// Each value received from dialGate decides one dial: true dials, false fails.
	dialGate := make(chan bool, 1)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		if <-dialGate {
			return net.Dial(n, addr)
		}
		return nil, errors.New("manually closed")
	}

	dialGate <- true // only allow one dial
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}
	// Finish /foo shortly after /bar has been issued below.
	time.AfterFunc(200*time.Millisecond, func() {
		// let the foo response finish so we can use its
		// connection for /bar
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
	})

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
	// Fail any dial still waiting on the gate.
	dialGate <- false
}
| |
// Issue 2184
//
// TestTransportReading100Continue sends numReqs POSTs through a Transport
// whose Dial returns an in-memory pipe connection. On that connection, a fake
// server answers each request with an interim "100 Continue" response
// immediately followed by a 200. The test checks that the client skips the
// 100 and returns the matching 200 (by echoed Request-Id) for every request.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response plays the server on one connection. It reads requests
	// from r and writes responses to w. It returns on EOF, on a read error,
	// or after answering the request whose Request-Id is reqID(numReqs).
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		// n counts requests seen on this connection, starting at 1.
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			// The interim status line defaults to "100 Continue" unless the
			// request sets X-Want-Response-Code (no request below does).
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				// Because n is per connection, this also requires the requests
				// to arrive in order on a single reused connection.
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// Interim response, blank line, then the final 200. The template
			// is written with \n and converted to CRLF line endings.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	tr := &Transport{
		// Each dial builds a connection from two pipes and starts a
		// send100Response goroutine to serve it.
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe() // server read/write
			cr, cw := io.Pipe() // client read/write
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks three things: the final status code,
	// the echoed Request-Id, and that the body reads without error.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Few 100 responses, making sure we're not off-by-one.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
| |
| // Issue 17739: the HTTP client must ignore any unknown 1xx |
| // informational responses before the actual response. |
| func TestTransportIgnore1xxResponses(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| conn, buf, _ := w.(Hijacker).Hijack() |
| buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) |
| buf.Flush() |
| conn.Close() |
| })) |
| defer cst.close() |
| cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway |
| |
| var got bytes.Buffer |
| |
| req, _ := NewRequest("GET", cst.ts.URL, nil) |
| req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ |
| Got1xxResponse: func(code int, header textproto.MIMEHeader) error { |
| fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) |
| return nil |
| }, |
| })) |
| res, err := cst.c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| |
| res.Write(&got) |
| want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" |
| if got.String() != want { |
| t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want) |
| } |
| } |
| |
| func TestTransportLimits1xxResponses(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| conn, buf, _ := w.(Hijacker).Hijack() |
| for i := 0; i < 10; i++ { |
| buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) |
| } |
| buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) |
| buf.Flush() |
| conn.Close() |
| })) |
| defer cst.close() |
| cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway |
| |
| res, err := cst.c.Get(cst.ts.URL) |
| if res != nil { |
| defer res.Body.Close() |
| } |
| got := fmt.Sprint(err) |
| wantSub := "too many 1xx informational responses" |
| if !strings.Contains(got, wantSub) { |
| t.Errorf("Get error = %v; want substring %q", err, wantSub) |
| } |
| } |
| |
| // Issue 26161: the HTTP client must treat 101 responses |
| // as the final response. |
| func TestTransportTreat101Terminal(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| conn, buf, _ := w.(Hijacker).Hijack() |
| buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) |
| buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) |
| buf.Flush() |
| conn.Close() |
| })) |
| defer cst.close() |
| res, err := cst.c.Get(cst.ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| if res.StatusCode != StatusSwitchingProtocols { |
| t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) |
| } |
| } |
| |
// proxyFromEnvTest is one ProxyFromEnvironment case. It holds the
// environment to install, the request URL to resolve, and the expected result.
type proxyFromEnvTest struct {
	req string // URL to fetch; blank means "http://example.com"

	env      string // HTTP_PROXY
	httpsenv string // HTTPS_PROXY
	noenv    string // NO_PROXY
	reqmeth  string // REQUEST_METHOD

	want    string // expected proxy URL as formatted by %s ("<nil>" when none)
	wanterr error  // expected error, compared by its formatted text
}

// String describes the case for failure messages. It emits space-separated
// key=%q pairs, skips empty environment values, and always ends with the
// request URL (defaulting to http://example.com).
func (t proxyFromEnvTest) String() string {
	req := t.req
	if req == "" {
		req = "http://example.com"
	}
	fields := []struct {
		format string
		value  string
	}{
		{"http_proxy=%q", t.env},
		{"https_proxy=%q", t.httpsenv},
		{"no_proxy=%q", t.noenv},
		{"request_method=%q", t.reqmeth},
		{"req=%q", req},
	}
	var buf bytes.Buffer
	for _, f := range fields {
		if f.value == "" {
			continue
		}
		if buf.Len() > 0 {
			buf.WriteByte(' ')
		}
		fmt.Fprintf(&buf, f.format, f.value)
	}
	return strings.TrimSpace(buf.String())
}
| |
// proxyFromEnvTests is the table shared by TestProxyFromEnvironment and
// TestProxyFromEnvironmentLowerCase. A want of "<nil>" means no proxy URL is
// expected.
var proxyFromEnvTests = []proxyFromEnvTest{
	// Proxy values without a scheme are expected to gain "http://";
	// explicit schemes (http, https, socks5) are kept as given.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},

	// Don't use secure for http
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// Use secure for https.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// Issue 16405: don't use HTTP_PROXY in a CGI environment,
	// where HTTP_PROXY can be attacker-controlled.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy variables set at all.
	{want: "<nil>"},

	// NO_PROXY matching: "example.com" matches the host and its subdomains,
	// while ".example.com", "ample.com" and ".foo.com" do not match
	// example.com itself.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
| |
| func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) { |
| t.Helper() |
| reqURL := tt.req |
| if reqURL == "" { |
| reqURL = "http://example.com" |
| } |
| req, _ := NewRequest("GET", reqURL, nil) |
| url, err := proxyForRequest(req) |
| if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { |
| t.Errorf("%v: got error = %q, want %q", tt, g, e) |
| return |
| } |
| if got := fmt.Sprintf("%s", url); got != tt.want { |
| t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) |
| } |
| } |
| |
| func TestProxyFromEnvironment(t *testing.T) { |
| ResetProxyEnv() |
| defer ResetProxyEnv() |
| for _, tt := range proxyFromEnvTests { |
| testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { |
| os.Setenv("HTTP_PROXY", tt.env) |
| os.Setenv("HTTPS_PROXY", tt.httpsenv) |
| os.Setenv("NO_PROXY", tt.noenv) |
| os.Setenv("REQUEST_METHOD", tt.reqmeth) |
| ResetCachedEnvironment() |
| return ProxyFromEnvironment(req) |
| }) |
| } |
| } |
| |
| func TestProxyFromEnvironmentLowerCase(t *testing.T) { |
| ResetProxyEnv() |
| defer ResetProxyEnv() |
| for _, tt := range proxyFromEnvTests { |
| testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { |
| os.Setenv("http_proxy", tt.env) |
| os.Setenv("https_proxy", tt.httpsenv) |
| os.Setenv("no_proxy", tt.noenv) |
| os.Setenv("REQUEST_METHOD", tt.reqmeth) |
| ResetCachedEnvironment() |
| return ProxyFromEnvironment(req) |
| }) |
| } |
| } |
| |
| func TestIdleConnChannelLeak(t *testing.T) { |
| // Not parallel: uses global test hooks. |
| var mu sync.Mutex |
| var n int |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| mu.Lock() |
| n++ |
| mu.Unlock() |
| })) |
| defer ts.Close() |
| |
| const nReqs = 5 |
| didRead := make(chan bool, nReqs) |
| SetReadLoopBeforeNextReadHook(func() { didRead <- true }) |
| defer SetReadLoopBeforeNextReadHook(nil) |
| |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| tr.Dial = func(netw, addr string) (net.Conn, error) { |
| return net.Dial(netw, ts.Listener.Addr().String()) |
| } |
| |
| // First, without keep-alives. |
| for _, disableKeep := range []bool{true, false} { |
| tr.DisableKeepAlives = disableKeep |
| for i := 0; i < nReqs; i++ { |
| _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| // Note: no res.Body.Close is needed here, since the |
| // response Content-Length is zero. Perhaps the test |
| // should be more explicit and use a HEAD, but tests |
| // elsewhere guarantee that zero byte responses generate |
| // a "Content-Length: 0" instead of chunking. |
| } |
| |
| // At this point, each of the 5 Transport.readLoop goroutines |
| // are scheduling noting that there are no response bodies (see |
| // earlier comment), and are then calling putIdleConn, which |
| // decrements this count. Usually that happens quickly, which is |
| // why this test has seemed to work for ages. But it's still |
| // racey: we have wait for them to finish first. See Issue 10427 |
| for i := 0; i < nReqs; i++ { |
| <-didRead |
| } |
| |
| if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { |
| t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) |
| } |
| } |
| } |
| |
| // Verify the status quo: that the Client.Post function coerces its |
| // body into a ReadCloser if it's a Closer, and that the Transport |
| // then closes it. |
| func TestTransportClosesRequestBody(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| io.Copy(io.Discard, r.Body) |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| |
| closes := 0 |
| |
| res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| if closes != 1 { |
| t.Errorf("closes = %d; want 1", closes) |
| } |
| } |
| |
// TestTransportTLSHandshakeTimeout connects to a peer that accepts the TCP
// connection but never answers the TLS handshake. Within
// Transport.TLSHandshakeTimeout, Get must fail with a *url.Error that wraps
// a net.Error whose Timeout() is true and whose text contains
// "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	defer ln.Close()
	testdonec := make(chan struct{})
	defer close(testdonec)

	// Silent peer: accept one connection and hold it open, writing nothing,
	// until the test returns.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-testdonec
		c.Close()
	}()

	// The client runs in its own goroutine so a hung handshake is detected
	// by the select below instead of blocking the test.
	getdonec := make(chan struct{})
	go func() {
		defer close(getdonec)
		tr := &Transport{
			Dial: func(_, _ string) (net.Conn, error) {
				return net.Dial("tcp", ln.Addr().String())
			},
			TLSHandshakeTimeout: 250 * time.Millisecond,
		}
		cl := &Client{Transport: tr}
		_, err := cl.Get("https://dummy.tld/")
		if err == nil {
			t.Error("expected error")
			return
		}
		ue, ok := err.(*url.Error)
		if !ok {
			t.Errorf("expected url.Error; got %#v", err)
			return
		}
		ne, ok := ue.Err.(net.Error)
		if !ok {
			t.Errorf("expected net.Error; got %#v", err)
			return
		}
		if !ne.Timeout() {
			t.Errorf("expected timeout error; got %v", err)
		}
		if !strings.Contains(err.Error(), "handshake timeout") {
			t.Errorf("expected 'handshake timeout' in error; got %v", err)
		}
	}()
	select {
	case <-getdonec:
	case <-time.After(5 * time.Second):
		t.Error("test timeout; TLS handshake hung?")
	}
}
| |
| // Trying to repro golang.org/issue/3514 |
| func TestTLSServerClosesConnection(t *testing.T) { |
| defer afterTest(t) |
| |
| closedc := make(chan bool, 1) |
| ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if strings.Contains(r.URL.Path, "/keep-alive-then-die") { |
| conn, _, _ := w.(Hijacker).Hijack() |
| conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) |
| conn.Close() |
| closedc <- true |
| return |
| } |
| fmt.Fprintf(w, "hello") |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| |
| var nSuccess = 0 |
| var errs []error |
| const trials = 20 |
| for i := 0; i < trials; i++ { |
| tr.CloseIdleConnections() |
| res, err := c.Get(ts.URL + "/keep-alive-then-die") |
| if err != nil { |
| t.Fatal(err) |
| } |
| <-closedc |
| slurp, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(slurp) != "foo" { |
| t.Errorf("Got %q, want foo", slurp) |
| } |
| |
| // Now try again and see if we successfully |
| // pick a new connection. |
| res, err = c.Get(ts.URL + "/") |
| if err != nil { |
| errs = append(errs, err) |
| continue |
| } |
| slurp, err = io.ReadAll(res.Body) |
| if err != nil { |
| errs = append(errs, err) |
| continue |
| } |
| nSuccess++ |
| } |
| if nSuccess > 0 { |
| t.Logf("successes = %d of %d", nSuccess, trials) |
| } else { |
| t.Errorf("All runs failed:") |
| } |
| for _, err := range errs { |
| t.Logf(" err: %v", err) |
| } |
| } |
| |
// byteFromChanReader is an io.Reader that delivers at most one byte per Read,
// received from the channel. Once the channel is closed and drained, Read
// returns io.EOF.
type byteFromChanReader chan byte

// Read receives one byte into p[0] and returns 1, nil. A zero-length p
// returns 0, nil without touching the channel. A closed channel yields 0, io.EOF.
func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		if b, ok := <-c; ok {
			p[0] = b
			return 1, nil
		}
		return 0, io.EOF
	}
	return 0, nil
}
| |
| // Verifies that the Transport doesn't reuse a connection in the case |
| // where the server replies before the request has been fully |
| // written. We still honor that reply (see TestIssue3595), but don't |
| // send future requests on the connection because it's then in a |
| // questionable state. |
| // golang.org/issue/7569 |
| func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| var sconn struct { |
| sync.Mutex |
| c net.Conn |
| } |
| var getOkay bool |
| closeConn := func() { |
| sconn.Lock() |
| defer sconn.Unlock() |
| if sconn.c != nil { |
| sconn.c.Close() |
| sconn.c = nil |
| if !getOkay { |
| t.Logf("Closed server connection") |
| } |
| } |
| } |
| defer closeConn() |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if r.Method == "GET" { |
| io.WriteString(w, "bar") |
| return |
| } |
| conn, _, _ := w.(Hijacker).Hijack() |
| sconn.Lock() |
| sconn.c = conn |
| sconn.Unlock() |
| conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive |
| go io.Copy(io.Discard, conn) |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| |
| const bodySize = 256 << 10 |
| finalBit := make(byteFromChanReader, 1) |
| req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) |
| req.ContentLength = bodySize |
| res, err := c.Do(req) |
| if err := wantBody(res, err, "foo"); err != nil { |
| t.Errorf("POST response: %v", err) |
| } |
| donec := make(chan bool) |
| go func() { |
| defer close(donec) |
| res, err = c.Get(ts.URL) |
| if err := wantBody(res, err, "bar"); err != nil { |
| t.Errorf("GET response: %v", err) |
| return |
| } |
| getOkay = true // suppress test noise |
| }() |
| time.AfterFunc(5*time.Second, closeConn) |
| select { |
| case <-donec: |
| finalBit <- 'x' // unblock the writeloop of the first Post |
| close(finalBit) |
| case <-time.After(7 * time.Second): |
| t.Fatal("timeout waiting for GET request to finish") |
| } |
| } |
| |
| // Tests that we don't leak Transport persistConn.readLoop goroutines |
| // when a server hangs up immediately after saying it would keep-alive. |
| func TestTransportIssue10457(t *testing.T) { |
| defer afterTest(t) // used to fail in goroutine leak check |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| // Send a response with no body, keep-alive |
| // (implicit), and then lie and immediately close the |
| // connection. This forces the Transport's readLoop to |
| // immediately Peek an io.EOF and get to the point |
| // that used to hang. |
| conn, _, _ := w.(Hijacker).Hijack() |
| conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive |
| conn.Close() |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatalf("Get: %v", err) |
| } |
| defer res.Body.Close() |
| |
| // Just a sanity check that we at least get the response. The real |
| // test here is that the "defer afterTest" above doesn't find any |
| // leaked goroutines. |
| if got, want := res.Header.Get("Foo"), "Bar"; got != want { |
| t.Errorf("Foo header = %q; want %q", got, want) |
| } |
| } |
| |
// closerFunc adapts a plain func() error to the io.Closer interface.
type closerFunc func() error

// Close calls fn and returns whatever it returns.
func (fn closerFunc) Close() error {
	return fn()
}
| |
// writerFuncConn is a net.Conn whose Write is routed to the write callback.
// All other net.Conn methods go to the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write hands p to the write callback and returns its results unchanged.
func (wc writerFuncConn) Write(p []byte) (n int, err error) {
	n, err = wc.write(p)
	return n, err
}
| |
| // Issues 4677, 18241, and 17844. If we try to reuse a connection that the |
| // server is in the process of closing, we may end up successfully writing out |
| // our request (or a portion of our request) only to find a connection error |
| // when we try to read from (or finish writing to) the socket. |
| // |
| // NOTE: we resend a request only if: |
| // - we reused a keep-alive connection |
| // - we haven't yet received any header data |
| // - either we wrote no bytes to the server, or the request is idempotent |
| // This automatically prevents an infinite resend loop because we'll run out of |
| // the cached keep-alive connections eventually. |
| func TestRetryRequestsOnError(t *testing.T) { |
| newRequest := func(method, urlStr string, body io.Reader) *Request { |
| req, err := NewRequest(method, urlStr, body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return req |
| } |
| |
| testCases := []struct { |
| name string |
| failureN int |
| failureErr error |
| // Note that we can't just re-use the Request object across calls to c.Do |
| // because we need to rewind Body between calls. (GetBody is only used to |
| // rewind Body on failure and redirects, not just because it's done.) |
| req func() *Request |
| reqString string |
| }{ |
| { |
| name: "IdempotentNoBodySomeWritten", |
| // Believe that we've written some bytes to the server, so we know we're |
| // not just in the "retry when no bytes sent" case". |
| failureN: 1, |
| // Use the specific error that shouldRetryRequest looks for with idempotent requests. |
| failureErr: ExportErrServerClosedIdle, |
| req: func() *Request { |
| return newRequest("GET", "http://fake.golang", nil) |
| }, |
| reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, |
| }, |
| { |
| name: "IdempotentGetBodySomeWritten", |
| // Believe that we've written some bytes to the server, so we know we're |
| // not just in the "retry when no bytes sent" case". |
| failureN: 1, |
| // Use the specific error that shouldRetryRequest looks for with idempotent requests. |
| failureErr: ExportErrServerClosedIdle, |
| req: func() *Request { |
| return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n")) |
| }, |
| reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, |
| }, |
| { |
| name: "NothingWrittenNoBody", |
| // It's key that we return 0 here -- that's what enables Transport to know |
| // that nothing was written, even though this is a non-idempotent request. |
| failureN: 0, |
| failureErr: errors.New("second write fails"), |
| req: func() *Request { |
| return newRequest("DELETE", "http://fake.golang", nil) |
| }, |
| reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, |
| }, |
| { |
| name: "NothingWrittenGetBody", |
| // It's key that we return 0 here -- that's what enables Transport to know |
| // that nothing was written, even though this is a non-idempotent request. |
| failureN: 0, |
| failureErr: errors.New("second write fails"), |
| // Note that NewRequest will set up GetBody for strings.Reader, which is |
| // required for the retry to occur |
| req: func() *Request { |
| return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n")) |
| }, |
| reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, |
| }, |
| } |
| |
| for _, tc := range testCases { |
| t.Run(tc.name, func(t *testing.T) { |
| defer afterTest(t) |
| |
| var ( |
| mu sync.Mutex |
| logbuf bytes.Buffer |
| ) |
| logf := func(format string, args ...any) { |
| mu.Lock() |
| defer mu.Unlock() |
| fmt.Fprintf(&logbuf, format, args...) |
| logbuf.WriteByte('\n') |
| } |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| logf("Handler") |
| w.Header().Set("X-Status", "ok") |
| })) |
| defer ts.Close() |
| |
| var writeNumAtomic int32 |
| c := ts.Client() |
| c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) { |
| logf("Dial") |
| c, err := net.Dial(network, ts.Listener.Addr().String()) |
| if err != nil { |
| logf("Dial error: %v", err) |
| return nil, err |
| } |
| return &writerFuncConn{ |
| Conn: c, |
| write: func(p []byte) (n int, err error) { |
| if atomic.AddInt32(&writeNumAtomic, 1) == 2 { |
| logf("intentional write failure") |
| return tc.failureN, tc.failureErr |
| } |
| logf("Write(%q)", p) |
| return c.Write(p) |
| }, |
| }, nil |
| } |
| |
| SetRoundTripRetried(func() { |
| logf("Retried.") |
| }) |
| defer SetRoundTripRetried(nil) |
| |
| for i := 0; i < 3; i++ { |
| t0 := time.Now() |
| req := tc.req() |
| res, err := c.Do(req) |
| if err != nil { |
| if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 { |
| mu.Lock() |
| got := logbuf.String() |
| mu.Unlock() |
| t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got) |
| } |
| t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse) |
| } |
| res.Body.Close() |
| if res.Request != req { |
| t.Errorf("Response.Request != original request; want identical Request") |
| } |
| } |
| |
| mu.Lock() |
| got := logbuf.String() |
| mu.Unlock() |
| want := fmt.Sprintf(`Dial |
| Write("%s") |
| Handler |
| intentional write failure |
| Retried. |
| Dial |
| Write("%s") |
| Handler |
| Write("%s") |
| Handler |
| `, tc.reqString, tc.reqString, tc.reqString) |
| if got != want { |
| t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want) |
| } |
| }) |
| } |
| } |
| |
| // Issue 6981 |
| func TestTransportClosesBodyOnError(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| readBody := make(chan error, 1) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| _, err := io.ReadAll(r.Body) |
| readBody <- err |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| fakeErr := errors.New("fake error") |
| didClose := make(chan bool, 1) |
| req, _ := NewRequest("POST", ts.URL, struct { |
| io.Reader |
| io.Closer |
| }{ |
| io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), |
| closerFunc(func() error { |
| select { |
| case didClose <- true: |
| default: |
| } |
| return nil |
| }), |
| }) |
| res, err := c.Do(req) |
| if res != nil { |
| defer res.Body.Close() |
| } |
| if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { |
| t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) |
| } |
| select { |
| case err := <-readBody: |
| if err == nil { |
| t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") |
| } |
| case <-time.After(5 * time.Second): |
| t.Error("timeout waiting for server handler to complete") |
| } |
| select { |
| case <-didClose: |
| default: |
| t.Errorf("didn't see Body.Close") |
| } |
| } |
| |
| func TestTransportDialTLS(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| var mu sync.Mutex // guards following |
| var gotReq, didDial bool |
| |
| ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| mu.Lock() |
| gotReq = true |
| mu.Unlock() |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { |
| mu.Lock() |
| didDial = true |
| mu.Unlock() |
| c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) |
| if err != nil { |
| return nil, err |
| } |
| return c, c.Handshake() |
| } |
| |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| mu.Lock() |
| if !gotReq { |
| t.Error("didn't get request") |
| } |
| if !didDial { |
| t.Error("didn't use dial hook") |
| } |
| } |
| |
| func TestTransportDialContext(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| var mu sync.Mutex // guards following |
| var gotReq bool |
| var receivedContext context.Context |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| mu.Lock() |
| gotReq = true |
| mu.Unlock() |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { |
| mu.Lock() |
| receivedContext = ctx |
| mu.Unlock() |
| return net.Dial(netw, addr) |
| } |
| |
| req, err := NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| ctx := context.WithValue(context.Background(), "some-key", "some-value") |
| res, err := c.Do(req.WithContext(ctx)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| mu.Lock() |
| if !gotReq { |
| t.Error("didn't get request") |
| } |
| if receivedContext != ctx { |
| t.Error("didn't receive correct context") |
| } |
| } |
| |
| func TestTransportDialTLSContext(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| var mu sync.Mutex // guards following |
| var gotReq bool |
| var receivedContext context.Context |
| |
| ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| mu.Lock() |
| gotReq = true |
| mu.Unlock() |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { |
| mu.Lock() |
| receivedContext = ctx |
| mu.Unlock() |
| c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) |
| if err != nil { |
| return nil, err |
| } |
| return c, c.HandshakeContext(ctx) |
| } |
| |
| req, err := NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| ctx := context.WithValue(context.Background(), "some-key", "some-value") |
| res, err := c.Do(req.WithContext(ctx)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| mu.Lock() |
| if !gotReq { |
| t.Error("didn't get request") |
| } |
| if receivedContext != ctx { |
| t.Error("didn't receive correct context") |
| } |
| } |
| |
// Test for issue 8755
// Ensure that if a proxy returns an error, it is exposed by RoundTrip.
// Checking only err != nil would also pass for unrelated failures, such as
// a direct dial to example.com failing, so the test checks that the proxy's
// own error text comes back.
func TestRoundTripReturnsProxyError(t *testing.T) {
	const proxyErrMsg = "errorMessage"
	badProxy := func(*Request) (*url.URL, error) {
		return nil, errors.New(proxyErrMsg)
	}

	tr := &Transport{Proxy: badProxy}

	req, err := NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatal(err)
	}

	res, err := tr.RoundTrip(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("Expected proxy error to be returned by RoundTrip")
	}
	if !strings.Contains(err.Error(), proxyErrMsg) {
		t.Errorf("RoundTrip error = %v; want it to contain the proxy error %q", err, proxyErrMsg)
	}
}
| |
| // tests that putting an idle conn after a call to CloseIdleConns does return it |
| func TestTransportCloseIdleConnsThenReturn(t *testing.T) { |
| tr := &Transport{} |
| wantIdle := func(when string, n int) bool { |
| got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn |
| if got == n { |
| return true |
| } |
| t.Errorf("%s: idle conns = %d; want %d", when, got, n) |
| return false |
| } |
| wantIdle("start", 0) |
| if !tr.PutIdleTestConn("http", "example.com") { |
| t.Fatal("put failed") |
| } |
| if !tr.PutIdleTestConn("http", "example.com") { |
| t.Fatal("second put failed") |
| } |
| wantIdle("after put", 2) |
| tr.CloseIdleConnections() |
| if !tr.IsIdleForTesting() { |
| t.Error("should be idle after CloseIdleConnections") |
| } |
| wantIdle("after close idle", 0) |
| if tr.PutIdleTestConn("http", "example.com") { |
| t.Fatal("put didn't fail") |
| } |
| wantIdle("after second put", 0) |
| |
| tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode |
| if tr.IsIdleForTesting() { |
| t.Error("shouldn't be idle after QueueForIdleConnForTesting") |
| } |
| if !tr.PutIdleTestConn("http", "example.com") { |
| t.Fatal("after re-activation") |
| } |
| wantIdle("after final put", 1) |
| } |
| |
| // Test for issue 34282 |
| // Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn |
| func TestTransportTraceGotConnH2IdleConns(t *testing.T) { |
| tr := &Transport{} |
| wantIdle := func(when string, n int) bool { |
| got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 |
| if got == n { |
| return true |
| } |
| t.Errorf("%s: idle conns = %d; want %d", when, got, n) |
| return false |
| } |
| wantIdle("start", 0) |
| alt := funcRoundTripper(func() {}) |
| if !tr.PutIdleTestConnH2("https", "example.com:443", alt) { |
| t.Fatal("put failed") |
| } |
| wantIdle("after put", 1) |
| ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ |
| GotConn: func(httptrace.GotConnInfo) { |
| // tr.getConn should leave it for the HTTP/2 alt to call GotConn. |
| t.Error("GotConn called") |
| }, |
| }) |
| req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil) |
| _, err := tr.RoundTrip(req) |
| if err != errFakeRoundTrip { |
| t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) |
| } |
| wantIdle("after round trip", 1) |
| } |
| |
| func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { |
| if testing.Short() { |
| t.Skip("skipping in short mode") |
| } |
| |
| trFunc := func(tr *Transport) { |
| tr.MaxConnsPerHost = 1 |
| tr.MaxIdleConnsPerHost = 1 |
| tr.IdleConnTimeout = 10 * time.Millisecond |
| } |
| cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) |
| defer cst.close() |
| |
| if _, err := cst.c.Get(cst.ts.URL); err != nil { |
| t.Fatalf("got error: %s", err) |
| } |
| |
| time.Sleep(100 * time.Millisecond) |
| got := make(chan error) |
| go func() { |
| if _, err := cst.c.Get(cst.ts.URL); err != nil { |
| got <- err |
| } |
| close(got) |
| }() |
| |
| timeout := time.NewTimer(5 * time.Second) |
| defer timeout.Stop() |
| select { |
| case err := <-got: |
| if err != nil { |
| t.Fatalf("got error: %s", err) |
| } |
| case <-timeout.C: |
| t.Fatal("request never completed") |
| } |
| } |
| |
| // This tests that a client requesting a content range won't also |
| // implicitly ask for gzip support. If they want that, they need to do it |
| // on their own. |
| // golang.org/issue/8923 |
| func TestTransportRangeAndGzip(t *testing.T) { |
| defer afterTest(t) |
| reqc := make(chan *Request, 1) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| reqc <- r |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| |
| req, _ := NewRequest("GET", ts.URL, nil) |
| req.Header.Set("Range", "bytes=7-11") |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| select { |
| case r := <-reqc: |
| if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { |
| t.Error("Transport advertised gzip support in the Accept header") |
| } |
| if r.Header.Get("Range") == "" { |
| t.Error("no Range in request") |
| } |
| case <-time.After(10 * time.Second): |
| t.Fatal("timeout") |
| } |
| res.Body.Close() |
| } |
| |
| // Test for issue 10474 |
| func TestTransportResponseCancelRace(t *testing.T) { |
| defer afterTest(t) |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| // important that this response has a body. |
| var b [1024]byte |
| w.Write(b[:]) |
| })) |
| defer ts.Close() |
| tr := ts.Client().Transport.(*Transport) |
| |
| req, err := NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res, err := tr.RoundTrip(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| // If we do an early close, Transport just throws the connection away and |
| // doesn't reuse it. In order to trigger the bug, it has to reuse the connection |
| // so read the body |
| if _, err := io.Copy(io.Discard, res.Body); err != nil { |
| t.Fatal(err) |
| } |
| |
| req2, err := NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| tr.CancelRequest(req) |
| res, err = tr.RoundTrip(req2) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| } |
| |
| // Test for issue 19248: Content-Encoding's value is case insensitive. |
| func TestTransportContentEncodingCaseInsensitive(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| for _, ce := range []string{"gzip", "GZIP"} { |
| ce := ce |
| t.Run(ce, func(t *testing.T) { |
| const encodedString = "Hello Gopher" |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.Header().Set("Content-Encoding", ce) |
| gz := gzip.NewWriter(w) |
| gz.Write([]byte(encodedString)) |
| gz.Close() |
| })) |
| defer ts.Close() |
| |
| res, err := ts.Client().Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| body, err := io.ReadAll(res.Body) |
| res.Body.Close() |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| if string(body) != encodedString { |
| t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) |
| } |
| }) |
| } |
| } |
| |
| func TestTransportDialCancelRace(t *testing.T) { |
| defer afterTest(t) |
| |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) |
| defer ts.Close() |
| tr := ts.Client().Transport.(*Transport) |
| |
| req, err := NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| SetEnterRoundTripHook(func() { |
| tr.CancelRequest(req) |
| }) |
| defer SetEnterRoundTripHook(nil) |
| res, err := tr.RoundTrip(req) |
| if err != ExportErrRequestCanceled { |
| t.Errorf("expected canceled request error; got %v", err) |
| if err == nil { |
| res.Body.Close() |
| } |
| } |
| } |
| |
// logWritesConn is a fake net.Conn.
//   - Write records each buffer, as a string, in writes and then forwards it
//     to w.
//   - Read serves data from a reader received on rch the first time Read is
//     called.
//
// Any other net.Conn method hits the nil embedded Conn and panics.
type logWritesConn struct {
	net.Conn // nil. crash on use.

	w io.Writer

	rch <-chan io.Reader
	r   io.Reader // nil until received by rch

	mu     sync.Mutex // guards writes
	writes []string
}

// Write appends a copy of p to c.writes and forwards p to c.w, all under c.mu.
func (c *logWritesConn) Write(p []byte) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, string(p))
	return c.w.Write(p)
}

// Read blocks for a reader on c.rch on first use, then reads from that reader.
func (c *logWritesConn) Read(p []byte) (int, error) {
	if c.r != nil {
		return c.r.Read(p)
	}
	c.r = <-c.rch
	return c.r.Read(p)
}

// Close is a no-op.
func (*logWritesConn) Close() error { return nil }
| |
| // Issue 6574 |
| func TestTransportFlushesBodyChunks(t *testing.T) { |
| defer afterTest(t) |
| resBody := make(chan io.Reader, 1) |
| connr, connw := io.Pipe() // connection pipe pair |
| lw := &logWritesConn{ |
| rch: resBody, |
| w: connw, |
| } |
| tr := &Transport{ |
| Dial: func(network, addr string) (net.Conn, error) { |
| return lw, nil |
| }, |
| } |
| bodyr, bodyw := io.Pipe() // body pipe pair |
| go func() { |
| defer bodyw.Close() |
| for i := 0; i < 3; i++ { |
| fmt.Fprintf(bodyw, "num%d\n", i) |
| } |
| }() |
| resc := make(chan *Response) |
| go func() { |
| req, _ := NewRequest("POST", "http://localhost:8080", bodyr) |
| req.Header.Set("User-Agent", "x") // known value for test |
| res, err := tr.RoundTrip(req) |
| if err != nil { |
| t.Errorf("RoundTrip: %v", err) |
| close(resc) |
| return |
| } |
| resc <- res |
| |
| }() |
| // Fully consume the request before checking the Write log vs. want. |
| req, err := ReadRequest(bufio.NewReader(connr)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| io.Copy(io.Discard, req.Body) |
| |
| // Unblock the transport's roundTrip goroutine. |
| resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") |
| res, ok := <-resc |
| if !ok { |
| return |
| } |
| defer res.Body.Close() |
| |
| want := []string{ |
| "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n", |
| "5\r\nnum0\n\r\n", |
| "5\r\nnum1\n\r\n", |
| "5\r\nnum2\n\r\n", |
| "0\r\n\r\n", |
| } |
| if !reflect.DeepEqual(lw.writes, want) { |
| t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want) |
| } |
| } |
| |
| // Issue 22088: flush Transport request headers if we're not sure the body won't block on read. |
| func TestTransportFlushesRequestHeader(t *testing.T) { |
| defer afterTest(t) |
| gotReq := make(chan struct{}) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| close(gotReq) |
| })) |
| defer cst.close() |
| |
| pr, pw := io.Pipe() |
| req, err := NewRequest("POST", cst.ts.URL, pr) |
| if err != nil { |
| t.Fatal(err) |
| } |
| gotRes := make(chan struct{}) |
| go func() { |
| defer close(gotRes) |
| res, err := cst.tr.RoundTrip(req) |
| if err != nil { |
| t.Error(err) |
| return |
| } |
| res.Body.Close() |
| }() |
| |
| select { |
| case <-gotReq: |
| pw.Close() |
| case <-time.After(5 * time.Second): |
| t.Fatal("timeout waiting for handler to get request") |
| } |
| <-gotRes |
| } |
| |
| // Issue 11745. |
| func TestTransportPrefersResponseOverWriteError(t *testing.T) { |
| if testing.Short() { |
| t.Skip("skipping in short mode") |
| } |
| defer afterTest(t) |
| const contentLengthLimit = 1024 * 1024 // 1MB |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if r.ContentLength >= contentLengthLimit { |
| w.WriteHeader(StatusBadRequest) |
| r.Body.Close() |
| return |
| } |
| w.WriteHeader(StatusOK) |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| |
| fail := 0 |
| count := 100 |
| bigBody := strings.Repeat("a", contentLengthLimit*2) |
| for i := 0; i < count; i++ { |
| req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody)) |
| if err != nil { |
| t.Fatal(err) |
| } |
| resp, err := c.Do(req) |
| if err != nil { |
| fail++ |
| t.Logf("%d = %#v", i, err) |
| if ue, ok := err.(*url.Error); ok { |
| t.Logf("urlErr = %#v", ue.Err) |
| if ne, ok := ue.Err.(*net.OpError); ok { |
| t.Logf("netOpError = %#v", ne.Err) |
| } |
| } |
| } else { |
| resp.Body.Close() |
| if resp.StatusCode != 400 { |
| t.Errorf("Expected status code 400, got %v", resp.Status) |
| } |
| } |
| } |
| if fail > 0 { |
| t.Errorf("Failed %v out of %v\n", fail, count) |
| } |
| } |
| |
| func TestTransportAutomaticHTTP2(t *testing.T) { |
| testTransportAutoHTTP(t, &Transport{}, true) |
| } |
| |
| func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) { |
| testTransportAutoHTTP(t, &Transport{ |
| ForceAttemptHTTP2: true, |
| TLSClientConfig: new(tls.Config), |
| }, true) |
| } |
| |
| // golang.org/issue/14391: also check DefaultTransport |
| func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { |
| testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) |
| } |
| |
| func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { |
| testTransportAutoHTTP(t, &Transport{ |
| TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper), |
| }, false) |
| } |
| |
| func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) { |
| testTransportAutoHTTP(t, &Transport{ |
| TLSClientConfig: new(tls.Config), |
| }, false) |
| } |
| |
| func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) { |
| testTransportAutoHTTP(t, &Transport{ |
| ExpectContinueTimeout: 1 * time.Second, |
| }, true) |
| } |
| |
| func TestTransportAutomaticHTTP2_Dial(t *testing.T) { |
| var d net.Dialer |
| testTransportAutoHTTP(t, &Transport{ |
| Dial: d.Dial, |
| }, false) |
| } |
| |
| func TestTransportAutomaticHTTP2_DialContext(t *testing.T) { |
| var d net.Dialer |
| testTransportAutoHTTP(t, &Transport{ |
| DialContext: d.DialContext, |
| }, false) |
| } |
| |
| func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { |
| testTransportAutoHTTP(t, &Transport{ |
| DialTLS: func(network, addr string) (net.Conn, error) { |
| panic("unused") |
| }, |
| }, false) |
| } |
| |
| func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { |
| CondSkipHTTP2(t) |
| _, err := tr.RoundTrip(new(Request)) |
| if err == nil { |
| t.Error("expected error from RoundTrip") |
| } |
| if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 { |
| t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2) |
| } |
| } |
| |
| // Issue 13633: there was a race where we returned bodyless responses |
| // to callers before recycling the persistent connection, which meant |
| // a client doing two subsequent requests could end up on different |
| // connections. It's somewhat harmless but enough tests assume it's |
| // not true in order to test other things that it's worth fixing. |
| // Plus it's nice to be consistent and not have timing-dependent |
| // behavior. |
| func TestTransportReuseConnEmptyResponseBody(t *testing.T) { |
| defer afterTest(t) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| w.Header().Set("X-Addr", r.RemoteAddr) |
| // Empty response body. |
| })) |
| defer cst.close() |
| n := 100 |
| if testing.Short() { |
| n = 10 |
| } |
| var firstAddr string |
| for i := 0; i < n; i++ { |
| res, err := cst.c.Get(cst.ts.URL) |
| if err != nil { |
| log.Fatal(err) |
| } |
| addr := res.Header.Get("X-Addr") |
| if i == 0 { |
| firstAddr = addr |
| } else if addr != firstAddr { |
| t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) |
| } |
| res.Body.Close() |
| } |
| } |
| |
| // Issue 13839 |
| func TestNoCrashReturningTransportAltConn(t *testing.T) { |
| cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) |
| if err != nil { |
| t.Fatal(err) |
| } |
| ln := newLocalListener(t) |
| defer ln.Close() |
| |
| var wg sync.WaitGroup |
| SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) |
| defer SetPendingDialHooks(nil, nil) |
| |
| testDone := make(chan struct{}) |
| defer close(testDone) |
| go func() { |
| tln := tls.NewListener(ln, &tls.Config{ |
| NextProtos: []string{"foo"}, |
| Certificates: []tls.Certificate{cert}, |
| }) |
| sc, err := tln.Accept() |
| if err != nil { |
| t.Error(err) |
| return |
| } |
| if err := sc.(*tls.Conn).Handshake(); err != nil { |
| t.Error(err) |
| return |
| } |
| <-testDone |
| sc.Close() |
| }() |
| |
| addr := ln.Addr().String() |
| |
| req, _ := NewRequest("GET", "https://fake.tld/", nil) |
| cancel := make(chan struct{}) |
| req.Cancel = cancel |
| |
| doReturned := make(chan bool, 1) |
| madeRoundTripper := make(chan bool, 1) |
| |
| tr := &Transport{ |
| DisableKeepAlives: true, |
| TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ |
| "foo": func(authority string, c *tls.Conn) RoundTripper { |
| madeRoundTripper <- true |
| return funcRoundTripper(func() { |
| t.Error("foo RoundTripper should not be called") |
| }) |
| }, |
| }, |
| Dial: func(_, _ string) (net.Conn, error) { |
| panic("shouldn't be called") |
| }, |
| DialTLS: func(_, _ string) (net.Conn, error) { |
| tc, err := tls.Dial("tcp", addr, &tls.Config{ |
| InsecureSkipVerify: true, |
| NextProtos: []string{"foo"}, |
| }) |
| if err != nil { |
| return nil, err |
| } |
| if err := tc.Handshake(); err != nil { |
| return nil, err |
| } |
| close(cancel) |
| <-doReturned |
| return tc, nil |
| }, |
| } |
| c := &Client{Transport: tr} |
| |
| _, err = c.Do(req) |
| if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn { |
| t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) |
| } |
| |
| doReturned <- true |
| <-madeRoundTripper |
| wg.Wait() |
| } |
| |
| func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { |
| testTransportReuseConnection_Gzip(t, true) |
| } |
| |
| func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { |
| testTransportReuseConnection_Gzip(t, false) |
| } |
| |
| // Make sure we re-use underlying TCP connection for gzipped responses too. |
| func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { |
| setParallel(t) |
| defer afterTest(t) |
| addr := make(chan string, 2) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| addr <- r.RemoteAddr |
| w.Header().Set("Content-Encoding", "gzip") |
| if chunked { |
| w.(Flusher).Flush() |
| } |
| w.Write(rgz) // arbitrary gzip response |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| |
| for i := 0; i < 2; i++ { |
| res, err := c.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| buf := make([]byte, len(rgz)) |
| if n, err := io.ReadFull(res.Body, buf); err != nil { |
| t.Errorf("%d. ReadFull = %v, %v", i, n, err) |
| } |
| // Note: no res.Body.Close call. It should work without it, |
| // since the flate.Reader's internal buffering will hit EOF |
| // and that should be sufficient. |
| } |
| a1, a2 := <-addr, <-addr |
| if a1 != a2 { |
| t.Fatalf("didn't reuse connection") |
| } |
| } |
| |
| func TestTransportResponseHeaderLength(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if r.URL.Path == "/long" { |
| w.Header().Set("Long", strings.Repeat("a", 1<<20)) |
| } |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 |
| |
| if res, err := c.Get(ts.URL); err != nil { |
| t.Fatal(err) |
| } else { |
| res.Body.Close() |
| } |
| |
| res, err := c.Get(ts.URL + "/long") |
| if err == nil { |
| defer res.Body.Close() |
| var n int64 |
| for k, vv := range res.Header { |
| for _, v := range vv { |
| n += int64(len(k)) + int64(len(v)) |
| } |
| } |
| t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) |
| } |
| if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { |
| t.Errorf("got error: %v; want %q", err, want) |
| } |
| } |
| |
| func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, h1Mode, false) } |
| func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) } |
| |
| // test a non-nil httptrace.ClientTrace but with all hooks set to zero. |
| func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, h1Mode, true) } |
| func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) } |
| |
// testTransportEventTrace sends a POST with "Expect: 100-continue" through a
// request context that carries an httptrace.ClientTrace and a fake DNS
// resolver. It then checks that the expected trace events were logged.
// Finally it sends a second (GET) request and checks that GetConn fired once
// more.
//
// h2 selects the HTTP/2 client/server pair (h2Mode) instead of HTTP/1
// (h1Mode). If noHooks is true, the ClientTrace is zeroed so every hook is
// nil. In that case the test only checks that the first request completes.
func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
	defer afterTest(t)
	const resBody = "some body"
	// Receives one value per WroteRequest event. It is buffered so the hook's
	// send does not block; the GET handler below never receives from it.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// Do nothing for the second request.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			// Hold the response until the client has reported WroteRequest,
			// so that event is logged before the response arrives.
			select {
			case <-gotWroteReqEvent:
			case <-time.After(5 * time.Second):
				t.Error("timeout waiting for WroteRequest event")
			}
		}
		io.WriteString(w, resBody)
	}))
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	var mu sync.Mutex // guards buf
	var buf bytes.Buffer
	// logf appends one formatted line to buf while holding mu.
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake DNS server.
	// It resolves only "dns-is-faked.golang", to the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	// Each hook logs a line that the assertions below search for.
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			// Unblocks the POST handler.
			gotWroteReqEvent <- struct{}{}
		},
	}
	if h2 {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// zero out all func pointers, trying to get some path to crash
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	// Exercises the Wait100Continue/Got100Continue hooks.
	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// Done at this point. Just testing a full HTTP
		// requests can happen with a trace pointing to a zero
		// ClientTrace, full of nil func pointers.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce fails the test unless sub appears exactly once in got.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	// wantOnceOrMore fails the test unless sub appears at least once in got.
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if h2 {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
		// WroteHeaderField hook is not yet implemented in h2.)
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// The fake resolver above means no real DNS traffic should occur.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// And do a second request:
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	// GetConn must have fired once per request.
	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
| |
| func TestTransportEventTraceTLSVerify(t *testing.T) { |
| var mu sync.Mutex |
| var buf bytes.Buffer |
| logf := func(format string, args ...any) { |
| mu.Lock() |
| defer mu.Unlock() |
| fmt.Fprintf(&buf, format, args...) |
| buf.WriteByte('\n') |
| } |
| |
| ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| t.Error("Unexpected request") |
| })) |
| defer ts.Close() |
| ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { |
| logf("%s", p) |
| return len(p), nil |
| }), "", 0) |
| |
| certpool := x509.NewCertPool() |
| certpool.AddCert(ts.Certificate()) |
| |
| c := &Client{Transport: &Transport{ |
| TLSClientConfig: &tls.Config{ |
| ServerName: "dns-is-faked.golang", |
| RootCAs: certpool, |
| }, |
| }} |
| |
| trace := &httptrace.ClientTrace{ |
| TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, |
| TLSHandshakeDone: func(s tls.ConnectionState, err error) { |
| logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) |
| }, |
| } |
| |
| req, _ := NewRequest("GET", ts.URL, nil) |
| req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) |
| _, err := c.Do(req) |
| if err == nil { |
| t.Error("Expected request to fail TLS verification") |
| } |
| |
| mu.Lock() |
| got := buf.String() |
| mu.Unlock() |
| |
| wantOnce := func(sub string) { |
| if strings.Count(got, sub) != 1 { |
| t.Errorf("expected substring %q exactly once in output.", sub) |
| } |
| } |
| |
| wantOnce("TLSHandshakeStart") |
| wantOnce("TLSHandshakeDone") |
| wantOnce("err = x509: certificate is valid for example.com") |
| |
| if t.Failed() { |
| t.Errorf("Output:\n%s", got) |
| } |
| } |
| |
// Cached result of the DNS-hijacking probe run by skipIfDNSHijacked.
var (
	isDNSHijackedOnce sync.Once // ensures the probe lookup runs at most once
	isDNSHijacked     bool      // true if a name that should not resolve did resolve
)
| |
| func skipIfDNSHijacked(t *testing.T) { |
| // Skip this test if the user is using a shady/ISP |
| // DNS server hijacking queries. |
| // See issues 16732, 16716. |
| isDNSHijackedOnce.Do(func() { |
| addrs, _ := net.LookupHost("dns-should-not-resolve.golang") |
| isDNSHijacked = len(addrs) != 0 |
| }) |
| if isDNSHijacked { |
| t.Skip("skipping; test requires non-hijacking DNS server") |
| } |
| } |
| |
| func TestTransportEventTraceRealDNS(t *testing.T) { |
| skipIfDNSHijacked(t) |
| defer afterTest(t) |
| tr := &Transport{} |
| defer tr.CloseIdleConnections() |
| c := &Client{Transport: tr} |
| |
| var mu sync.Mutex // guards buf |
| var buf bytes.Buffer |
| logf := func(format string, args ...any) { |
| mu.Lock() |
| defer mu.Unlock() |
| fmt.Fprintf(&buf, format, args...) |
| buf.WriteByte('\n') |
| } |
| |
| req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) |
| trace := &httptrace.ClientTrace{ |
| DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, |
| DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, |
| ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, |
| ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, |
| } |
| req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) |
| |
| resp, err := c.Do(req) |
| if err == nil { |
| resp.Body.Close() |
| t.Fatal("expected error during DNS lookup") |
| } |
| |
| mu.Lock() |
| got := buf.String() |
| mu.Unlock() |
| |
| wantSub := func(sub string) { |
| if !strings.Contains(got, sub) { |
| t.Errorf("expected substring %q in output.", sub) |
| } |
| } |
| wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") |
| wantSub("DNSDone: {Addrs:[] Err:") |
| if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { |
| t.Errorf("should not see Connect events") |
| } |
| if t.Failed() { |
| t.Errorf("Output:\n%s", got) |
| } |
| } |
| |
// Issue 14353: port can only contain digits.
//
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// non-digits is rejected, before any network I/O, with a *url.Error whose
// inner error names the bad port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	// errors.As finds the outermost *url.Error in the chain, which is the
	// one Get returns directly.
	var ue *url.Error
	if !errors.As(err, &ue) {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	got := ue.Err.Error()
	want := `invalid port ":123foo" after host`
	if got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
| |
| // Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 |
| // connections. The http2 test is done in TestTransportEventTrace_h2 |
| func TestTLSHandshakeTrace(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) |
| defer ts.Close() |
| |
| var mu sync.Mutex |
| var start, done bool |
| trace := &httptrace.ClientTrace{ |
| TLSHandshakeStart: func() { |
| mu.Lock() |
| defer mu.Unlock() |
| start = true |
| }, |
| TLSHandshakeDone: func(s tls.ConnectionState, err error) { |
| mu.Lock() |
| defer mu.Unlock() |
| done = true |
| if err != nil { |
| t.Fatal("Expected error to be nil but was:", err) |
| } |
| }, |
| } |
| |
| c := ts.Client() |
| req, err := NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Fatal("Unable to construct test request:", err) |
| } |
| req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) |
| |
| r, err := c.Do(req) |
| if err != nil { |
| t.Fatal("Unexpected error making request:", err) |
| } |
| r.Body.Close() |
| mu.Lock() |
| defer mu.Unlock() |
| if !start { |
| t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") |
| } |
| if !done { |
| t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") |
| } |
| } |
| |
| func TestTransportMaxIdleConns(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| // No body for convenience. |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| tr.MaxIdleConns = 4 |
| |
| ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) { |
| return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil |
| }) |
| |
| hitHost := func(n int) { |
| req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil) |
| req = req.WithContext(ctx) |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| } |
| for i := 0; i < 4; i++ { |
| hitHost(i) |
| } |
| want := []string{ |
| "|http|host-0.dns-is-faked.golang:" + port, |
| "|http|host-1.dns-is-faked.golang:" + port, |
| "|http|host-2.dns-is-faked.golang:" + port, |
| "|http|host-3.dns-is-faked.golang:" + port, |
| } |
| if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { |
| t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want) |
| } |
| |
| // Now hitting the 5th host should kick out the first host: |
| hitHost(4) |
| want = []string{ |
| "|http|host-1.dns-is-faked.golang:" + port, |
| "|http|host-2.dns-is-faked.golang:" + port, |
| "|http|host-3.dns-is-faked.golang:" + port, |
| "|http|host-4.dns-is-faked.golang:" + port, |
| } |
| if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { |
| t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want) |
| } |
| } |
| |
| func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) } |
| func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) } |
| func testTransportIdleConnTimeout(t *testing.T, h2 bool) { |
| if testing.Short() { |
| t.Skip("skipping in short mode") |
| } |
| defer afterTest(t) |
| |
| const timeout = 1 * time.Second |
| |
| cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { |
| // No body for convenience. |
| })) |
| defer cst.close() |
| tr := cst.tr |
| tr.IdleConnTimeout = timeout |
| defer tr.CloseIdleConnections() |
| c := &Client{Transport: tr} |
| |
| idleConns := func() []string { |
| if h2 { |
| return tr.IdleConnStrsForTesting_h2() |
| } else { |
| return tr.IdleConnStrsForTesting() |
| } |
| } |
| |
| var conn string |
| doReq := func(n int) { |
| req, _ := NewRequest("GET", cst.ts.URL, nil) |
| req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ |
| PutIdleConn: func(err error) { |
| if err != nil { |
| t.Errorf("failed to keep idle conn: %v", err) |
| } |
| }, |
| })) |
| res, err := c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res.Body.Close() |
| conns := idleConns() |
| if len(conns) != 1 { |
| t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) |
| } |
| if conn == "" { |
| conn = conns[0] |
| } |
| if conn != conns[0] { |
| t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n) |
| } |
| } |
| for i := 0; i < 3; i++ { |
| doReq(i) |
| time.Sleep(timeout / 2) |
| } |
| time.Sleep(timeout * 3 / 2) |
| if got := idleConns(); len(got) != 0 { |
| t.Errorf("idle conns = %q; want none", got) |
| } |
| } |
| |
| // Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an |
| // HTTP/2 connection was established but its caller no longer |
| // wanted it. (Assuming the connection cache was enabled, which it is |
| // by default) |
| // |
| // This test reproduced the crash by setting the IdleConnTimeout low |
| // (to make the test reasonable) and then making a request which is |
| // canceled by the DialTLS hook, which then also waits to return the |
| // real connection until after the RoundTrip saw the error. Then we |
| // know the successful tls.Dial from DialTLS will need to go into the |
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) {
	setParallel(t)
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// nothing
	}))
	defer cst.close()

	// ctx is canceled from inside DialTLS, so the request fails while the
	// dial is still in progress.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is signaled after cst.c.Do has returned its error. It is
	// buffered so the send below never blocks, even if DialTLS already
	// stopped waiting.
	sawDoErr := make(chan bool, 1)
	// testDone is closed when the test returns, releasing a DialTLS that is
	// still waiting.
	testDone := make(chan struct{})
	defer close(testDone)

	// Tiny idle timeout so the idle-conn expiry happens during the test.
	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request, then hold on to the successfully dialed
		// connection until Do has reported its error. That way the
		// connection is returned with no caller waiting for it.
		cancel()

		failTimer := time.NewTimer(5 * time.Second)
		defer failTimer.Stop()
		select {
		case <-sawDoErr:
		case <-testDone:
		case <-failTimer.C:
			t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail")
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Wait for the explosion.
	// Ten idle timeouts give the idle-conn expiry ample time to run.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
| |
// funcConn is a net.Conn whose Read and Write are delegated to the read and
// write funcs, and whose Close does nothing and returns nil. All other
// net.Conn methods go to the embedded Conn, which callers may leave nil.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards to fc.read.
func (fc funcConn) Read(b []byte) (int, error) {
	return fc.read(b)
}

// Write forwards to fc.write.
func (fc funcConn) Write(b []byte) (int, error) {
	return fc.write(b)
}

// Close is a no-op that always succeeds.
func (funcConn) Close() error {
	return nil
}
| |
| // Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek |
| // back to the caller. |
| func TestTransportReturnsPeekError(t *testing.T) { |
| errValue := errors.New("specific error value") |
| |
| wrote := make(chan struct{}) |
| var wroteOnce sync.Once |
| |
| tr := &Transport{ |
| Dial: func(network, addr string) (net.Conn, error) { |
| c := funcConn{ |
| read: func([]byte) (int, error) { |
| <-wrote |
| return 0, errValue |
| }, |
| write: func(p []byte) (int, error) { |
| wroteOnce.Do(func() { close(wrote) }) |
| return len(p), nil |
| }, |
| } |
| return c, nil |
| }, |
| } |
| _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) |
| if err != errValue { |
| t.Errorf("error = %#v; want %v", err, errValue) |
| } |
| } |
| |
| // Issue 13835: international domain names should work |
| func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) } |
| func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) } |
| func testTransportIDNA(t *testing.T, h2 bool) { |
| defer afterTest(t) |
| |
| const uniDomain = "гофер.го" |
| const punyDomain = "xn--c1ae0ajs.xn--c1aw" |
| |
| var port string |
| cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { |
| want := punyDomain + ":" + port |
| if r.Host != want { |
| t.Errorf("Host header = %q; want %q", r.Host, want) |
| } |
| if h2 { |
| if r.TLS == nil { |
| t.Errorf("r.TLS == nil") |
| } else if r.TLS.ServerName != punyDomain { |
| t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain) |
| } |
| } |
| w.Header().Set("Hit-Handler", "1") |
| })) |
| defer cst.close() |
| |
| ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| // Install a fake DNS server. |
| ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { |
| if host != punyDomain { |
| t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain) |
| return nil, nil |
| } |
| return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil |
| }) |
| |
| req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil) |
| trace := &httptrace.ClientTrace{ |
| GetConn: func(hostPort string) { |
| want := net.JoinHostPort(punyDomain, port) |
| if hostPort != want { |
| t.Errorf("getting conn for %q; want %q", hostPort, want) |
| } |
| }, |
| DNSStart: func(e httptrace.DNSStartInfo) { |
| if e.Host != punyDomain { |
| t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain) |
| } |
| }, |
| } |
| req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) |
| |
| res, err := cst.tr.RoundTrip(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| if res.Header.Get("Hit-Handler") != "1" { |
| out, err := httputil.DumpResponse(res, true) |
| if err != nil { |
| t.Fatal(err) |
| } |
| t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out) |
| } |
| } |
| |
| // Issue 13290: send User-Agent in proxy CONNECT |
| func TestTransportProxyConnectHeader(t *testing.T) { |
| defer afterTest(t) |
| reqc := make(chan *Request, 1) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if r.Method != "CONNECT" { |
| t.Errorf("method = %q; want CONNECT", r.Method) |
| } |
| reqc <- r |
| c, _, err := w.(Hijacker).Hijack() |
| if err != nil { |
| t.Errorf("Hijack: %v", err) |
| return |
| } |
| c.Close() |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { |
| return url.Parse(ts.URL) |
| } |
| c.Transport.(*Transport).ProxyConnectHeader = Header{ |
| "User-Agent": {"foo"}, |
| "Other": {"bar"}, |
| } |
| |
| res, err := c.Get("https://dummy.tld/") // https to force a CONNECT |
| if err == nil { |
| res.Body.Close() |
| t.Errorf("unexpected success") |
| } |
| select { |
| case <-time.After(3 * time.Second): |
| t.Fatal("timeout") |
| case r := <-reqc: |
| if got, want := r.Header.Get("User-Agent"), "foo"; got != want { |
| t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) |
| } |
| if got, want := r.Header.Get("Other"), "bar"; got != want { |
| t.Errorf("CONNECT request Other = %q; want %q", got, want) |
| } |
| } |
| } |
| |
| func TestTransportProxyGetConnectHeader(t *testing.T) { |
| defer afterTest(t) |
| reqc := make(chan *Request, 1) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if r.Method != "CONNECT" { |
| t.Errorf("method = %q; want CONNECT", r.Method) |
| } |
| reqc <- r |
| c, _, err := w.(Hijacker).Hijack() |
| if err != nil { |
| t.Errorf("Hijack: %v", err) |
| return |
| } |
| c.Close() |
| })) |
| defer ts.Close() |
| |
| c := ts.Client() |
| c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { |
| return url.Parse(ts.URL) |
| } |
| // These should be ignored: |
| c.Transport.(*Transport).ProxyConnectHeader = Header{ |
| "User-Agent": {"foo"}, |
| "Other": {"bar"}, |
| } |
| c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) { |
| return Header{ |
| "User-Agent": {"foo2"}, |
| "Other": {"bar2"}, |
| }, nil |
| } |
| |
| res, err := c.Get("https://dummy.tld/") // https to force a CONNECT |
| if err == nil { |
| res.Body.Close() |
| t.Errorf("unexpected success") |
| } |
| select { |
| case <-time.After(3 * time.Second): |
| t.Fatal("timeout") |
| case r := <-reqc: |
| if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { |
| t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) |
| } |
| if got, want := r.Header.Get("Other"), "bar2"; got != want { |
| t.Errorf("CONNECT request Other = %q; want %q", got, want) |
| } |
| } |
| } |
| |
// errFakeRoundTrip is the error every funcRoundTripper returns.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that calls the underlying func and then
// fails the request with errFakeRoundTrip.
type funcRoundTripper func()

// RoundTrip ignores the request, calls f, and returns (nil, errFakeRoundTrip).
func (f funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	f()
	var noResponse *Response
	return noResponse, errFakeRoundTrip
}
| |
// wantBody checks the result of an HTTP request.
//
// If err is non-nil it is returned unchanged and res is not touched.
// Otherwise res.Body is read in full and closed. The returned error describes
// the first problem found: a read failure, a body that differs from want, or
// a Close failure. res.Body is closed on every path that reads it, so a
// failed check does not leak the response body.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		res.Body.Close() // best effort; the read error is what gets reported
		return fmt.Errorf("error reading body: %v", err)
	}
	if string(slurp) != want {
		res.Body.Close() // best effort; the mismatch is what gets reported
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if err := res.Body.Close(); err != nil {
		return fmt.Errorf("body Close = %v", err)
	}
	return nil
}
| |
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// prefers IPv4 (127.0.0.1) and falls back to IPv6 ([::1]). If neither works,
// it fails the test with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return l
	}
	if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
		t.Fatal(err)
	}
	return l
}
| |
// countCloseReader pairs an io.Reader with a Close method that increments
// *n and always succeeds. Reads go to the embedded io.Reader.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close records one more call in *r.n and returns nil.
func (r countCloseReader) Close() error {
	*r.n += 1
	return nil
}
| |
// rgz is a gzip quine that uncompresses to itself.
//
// The bytes must be kept exactly as they are: changing any of them breaks
// the self-reproducing property.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
| |
| // Ensure that a missing status doesn't make the server panic |
| // See Issue https://golang.org/issues/21701 |
| func TestMissingStatusNoPanic(t *testing.T) { |
| t.Parallel() |
| |
| const want = "unknown status code" |
| |
| ln := newLocalListener(t) |
| addr := ln.Addr().String() |
| done := make(chan bool) |
| fullAddrURL := fmt.Sprintf("http://%s", addr) |
| raw := "HTTP/1.1 400\r\n" + |
| "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + |
| "Content-Type: text/html; charset=utf-8\r\n" + |
| "Content-Length: 10\r\n" + |
| "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + |
| "Vary: Accept-Encoding\r\n\r\n" + |
| "Aloha Olaa" |
| |
| go func() { |
| defer close(done) |
| |
| conn, _ := ln.Accept() |
| if conn != nil { |
| io.WriteString(conn, raw) |
| io.ReadAll(conn) |
| conn.Close() |
| } |
| }() |
| |
| proxyURL, err := url.Parse(fullAddrURL) |
| if err != nil { |
| t.Fatalf("proxyURL: %v", err) |
| } |
| |
| tr := &Transport{Proxy: ProxyURL(proxyURL)} |
| |
| req, _ := NewRequest("GET", "https://golang.org/", nil) |
| res, err, panicked := doFetchCheckPanic(tr, req) |
| if panicked { |
| t.Error("panicked, expecting an error") |
| } |
| if res != nil && res.Body != nil { |
| io.Copy(io.Discard, res.Body) |
| res.Body.Close() |
| } |
| |
| if err == nil || !strings.Contains(err.Error(), want) { |
| t.Errorf("got=%v want=%q", err, want) |
| } |
| |
| ln.Close() |
| <-done |
| } |
| |
// doFetchCheckPanic runs tr.RoundTrip(req) and reports whether it panicked.
// A panic is recovered: panicked is then true, and res and err keep their
// zero values. Otherwise res and err are RoundTrip's results and panicked
// is false.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
| |
| // Issue 22330: do not allow the response body to be read when the status code |
| // forbids a response body. |
| func TestNoBodyOnChunked304Response(t *testing.T) { |
| defer afterTest(t) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| conn, buf, _ := w.(Hijacker).Hijack() |
| buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) |
| buf.Flush() |
| conn.Close() |
| })) |
| defer cst.close() |
| |
| // Our test server above is sending back bogus data after the |
| // response (the "0\r\n\r\n" part), which causes the Transport |
| // code to log spam. Disable keep-alives so we never even try |
| // to reuse the connection. |
| cst.tr.DisableKeepAlives = true |
| |
| res, err := cst.c.Get(cst.ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| if res.Body != NoBody { |
| t.Errorf("Unexpected body on 304 response") |
| } |
| } |
| |
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write passes p to the wrapped function and returns its results.
func (fw funcWriter) Write(p []byte) (int, error) {
	return fw(p)
}
| |
// doneContext wraps a Context so that it always looks done. Done returns an
// already-closed channel and Err returns the fixed err field. Deadline and
// Value go to the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a new, already-closed channel on every call.
func (doneContext) Done() <-chan struct{} {
	ch := make(chan struct{})
	defer close(ch)
	return ch
}

// Err returns the stored error, regardless of the embedded Context's state.
func (dc doneContext) Err() error {
	return dc.err
}
| |
| // Issue 25852: Transport should check whether Context is done early. |
| func TestTransportCheckContextDoneEarly(t *testing.T) { |
| tr := &Transport{} |
| req, _ := NewRequest("GET", "http://fake.example/", nil) |
| wantErr := errors.New("some error") |
| req = req.WithContext(doneContext{context.Background(), wantErr}) |
| _, err := tr.RoundTrip(req) |
| if err != wantErr { |
| t.Errorf("error = %v; want %v", err, wantErr) |
| } |
| } |
| |
| // Issue 23399: verify that if a client request times out, the Transport's |
| // conn is closed so that it's not reused. |
| // |
| // This is the test variant that times out before the server replies with |
| // any response headers. |
| func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| inHandler := make(chan net.Conn, 1) |
| handlerReadReturned := make(chan bool, 1) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| conn, _, err := w.(Hijacker).Hijack() |
| if err != nil { |
| t.Error(err) |
| return |
| } |
| inHandler <- conn |
| n, err := conn.Read([]byte{0}) |
| if n != 0 || err != io.EOF { |
| t.Errorf("unexpected Read result: %v, %v", n, err) |
| } |
| handlerReadReturned <- true |
| })) |
| defer cst.close() |
| |
| const timeout = 50 * time.Millisecond |
| cst.c.Timeout = timeout |
| |
| _, err := cst.c.Get(cst.ts.URL) |
| if err == nil { |
| t.Fatal("unexpected Get succeess") |
| } |
| |
| select { |
| case c := <-inHandler: |
| select { |
| case <-handlerReadReturned: |
| // Success. |
| return |
| case <-time.After(5 * time.Second): |
| t.Error("Handler's conn.Read seems to be stuck in Read") |
| c.Close() // close it to unblock Handler |
| } |
| case <-time.After(timeout * 10): |
| // If we didn't get into the Handler in 50ms, that probably means |
| // the builder was just slow and the Get failed in that time |
| // but never made it to the server. That's fine. We'll usually |
| // test the part above on faster machines. |
| t.Skip("skipping test on slow builder") |
| } |
| } |
| |
// Issue 23399: verify that if a client request times out, the Transport's
// conn is closed so that it's not reused.
//
// This is the test variant that has the server send response headers
// first, and time out during the write of the response body.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	inHandler := make(chan net.Conn, 1)
	handlerResult := make(chan error, 1)
	cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise 100 body bytes but send only 3, so the client stays
		// blocked reading the body.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		conn.Write([]byte("foo"))
		inHandler <- conn
		n, err := conn.Read([]byte{0})
		// The error should be io.EOF or "read tcp
		// 127.0.0.1:35827->127.0.0.1:40290: read: connection
		// reset by peer" depending on timing. Really we just
		// care that it returns at all. But if it returns with
		// data, that's weird.
		if n != 0 || err == nil {
			handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err)
			return
		}
		handlerResult <- nil
	}))
	defer cst.close()

	// Set Timeout to something very long but non-zero to exercise
	// the codepaths that check for it. But rather than wait for it to fire
	// (which would make the test slow), we close the req.Cancel channel instead,
	// which happens to exercise the same code paths.
	cst.c.Timeout = time.Minute // just to be non-zero, not to hit it.
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := cst.c.Do(req)
	if err != nil {
		select {
		case <-inHandler:
			t.Fatalf("Get error: %v", err)
		default:
			// Failed before entering handler. Ignore result.
			t.Skip("skipping test on slow builder")
		}
	}

	// Cancel mid-body; the body read must then fail.
	close(cancel)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Fatalf("unexpected success; read %q, nil", got)
	}

	select {
	case c := <-inHandler:
		select {
		case err := <-handlerResult:
			if err != nil {
				t.Errorf("handler: %v", err)
			}
			return
		case <-time.After(5 * time.Second):
			t.Error("Handler's conn.Read seems to be stuck in Read")
			c.Close() // close it to unblock Handler
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout")
	}
}
| |
| func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| done := make(chan struct{}) |
| defer close(done) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| conn, _, err := w.(Hijacker).Hijack() |
| if err != nil { |
| t.Error(err) |
| return |
| } |
| defer conn.Close() |
| io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") |
| bs := bufio.NewScanner(conn) |
| bs.Scan() |
| fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) |
| <-done |
| })) |
| defer cst.close() |
| |
| req, _ := NewRequest("GET", cst.ts.URL, nil) |
| req.Header.Set("Upgrade", "foo") |
| req.Header.Set("Connection", "upgrade") |
| res, err := cst.c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if res.StatusCode != 101 { |
| t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) |
| } |
| rwc, ok := res.Body.(io.ReadWriteCloser) |
| if !ok { |
| t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) |
| } |
| defer rwc.Close() |
| bs := bufio.NewScanner(rwc) |
| if !bs.Scan() { |
| t.Fatalf("expected readable input") |
| } |
| if got, want := bs.Text(), "Some buffered data"; got != want { |
| t.Errorf("read %q; want %q", got, want) |
| } |
| io.WriteString(rwc, "echo\n") |
| if !bs.Scan() { |
| t.Fatalf("expected another line") |
| } |
| if got, want := bs.Text(), "ECHO"; got != want { |
| t.Errorf("read %q; want %q", got, want) |
| } |
| } |
| |
| func TestTransportCONNECTBidi(t *testing.T) { |
| defer afterTest(t) |
| const target = "backend:443" |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| if r.Method != "CONNECT" { |
| t.Errorf("unexpected method %q", r.Method) |
| w.WriteHeader(500) |
| return |
| } |
| if r.RequestURI != target { |
| t.Errorf("unexpected CONNECT target %q", r.RequestURI) |
| w.WriteHeader(500) |
| return |
| } |
| nc, brw, err := w.(Hijacker).Hijack() |
| if err != nil { |
| t.Error(err) |
| return |
| } |
| defer nc.Close() |
| nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) |
| // Switch to a little protocol that capitalize its input lines: |
| for { |
| line, err := brw.ReadString('\n') |
| if err != nil { |
| if err != io.EOF { |
| t.Error(err) |
| } |
| return |
| } |
| io.WriteString(brw, strings.ToUpper(line)) |
| brw.Flush() |
| } |
| })) |
| defer cst.close() |
| pr, pw := io.Pipe() |
| defer pw.Close() |
| req, err := NewRequest("CONNECT", cst.ts.URL, pr) |
| if err != nil { |
| t.Fatal(err) |
| } |
| req.URL.Opaque = target |
| res, err := cst.c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| if res.StatusCode != 200 { |
| t.Fatalf("status code = %d; want 200", res.StatusCode) |
| } |
| br := bufio.NewReader(res.Body) |
| for _, str := range []string{"foo", "bar", "baz"} { |
| fmt.Fprintf(pw, "%s\n", str) |
| got, err := br.ReadString('\n') |
| if err != nil { |
| t.Fatal(err) |
| } |
| got = strings.TrimSpace(got) |
| want := strings.ToUpper(str) |
| if got != want { |
| t.Fatalf("got %q; want %q", got, want) |
| } |
| } |
| } |
| |
| func TestTransportRequestReplayable(t *testing.T) { |
| someBody := io.NopCloser(strings.NewReader("")) |
| tests := []struct { |
| name string |
| req *Request |
| want bool |
| }{ |
| { |
| name: "GET", |
| req: &Request{Method: "GET"}, |
| want: true, |
| }, |
| { |
| name: "GET_http.NoBody", |
| req: &Request{Method: "GET", Body: NoBody}, |
| want: true, |
| }, |
| { |
| name: "GET_body", |
| req: &Request{Method: "GET", Body: someBody}, |
| want: false, |
| }, |
| { |
| name: "POST", |
| req: &Request{Method: "POST"}, |
| want: false, |
| }, |
| { |
| name: "POST_idempotency-key", |
| req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}}, |
| want: true, |
| }, |
| { |
| name: "POST_x-idempotency-key", |
| req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}}, |
| want: true, |
| }, |
| { |
| name: "POST_body", |
| req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody}, |
| want: false, |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| got := tt.req.ExportIsReplayable() |
| if got != tt.want { |
| t.Errorf("replyable = %v; want %v", got, tt.want) |
| } |
| }) |
| } |
| } |
| |
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method was invoked. Tests use it to see whether the Transport took the
// io.ReaderFrom path when sending a request body.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool // set to true by every ReadFrom call
}

// ReadFrom records the call and then forwards to the embedded connection.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (n int64, err error) {
	m.ReadFromCalled = true
	n, err = m.TCPConn.ReadFrom(src)
	return n, err
}
| |
| func TestTransportRequestWriteRoundTrip(t *testing.T) { |
| nBytes := int64(1 << 10) |
| newFileFunc := func() (r io.Reader, done func(), err error) { |
| f, err := os.CreateTemp("", "net-http-newfilefunc") |
| if err != nil { |
| return nil, nil, err |
| } |
| |
| // Write some bytes to the file to enable reading. |
| if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { |
| return nil, nil, fmt.Errorf("failed to write data to file: %v", err) |
| } |
| if _, err := f.Seek(0, 0); err != nil { |
| return nil, nil, fmt.Errorf("failed to seek to front: %v", err) |
| } |
| |
| done = func() { |
| f.Close() |
| os.Remove(f.Name()) |
| } |
| |
| return f, done, nil |
| } |
| |
| newBufferFunc := func() (io.Reader, func(), error) { |
| return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil |
| } |
| |
| cases := []struct { |
| name string |
| readerFunc func() (io.Reader, func(), error) |
| contentLength int64 |
| expectedReadFrom bool |
| }{ |
| { |
| name: "file, length", |
| readerFunc: newFileFunc, |
| contentLength: nBytes, |
| expectedReadFrom: true, |
| }, |
| { |
| name: "file, no length", |
| readerFunc: newFileFunc, |
| }, |
| { |
| name: "file, negative length", |
| readerFunc: newFileFunc, |
| contentLength: -1, |
| }, |
| { |
| name: "buffer", |
| contentLength: nBytes, |
| readerFunc: newBufferFunc, |
| }, |
| { |
| name: "buffer, no length", |
| readerFunc: newBufferFunc, |
| }, |
| { |
| name: "buffer, length -1", |
| contentLength: -1, |
| readerFunc: newBufferFunc, |
| }, |
| } |
| |
| for _, tc := range cases { |
| t.Run(tc.name, func(t *testing.T) { |
| r, cleanup, err := tc.readerFunc() |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer cleanup() |
| |
| tConn := &testMockTCPConn{} |
| trFunc := func(tr *Transport) { |
| tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { |
| var d net.Dialer |
| conn, err := d.DialContext(ctx, network, addr) |
| if err != nil { |
| return nil, err |
| } |
| |
| tcpConn, ok := conn.(*net.TCPConn) |
| if !ok { |
| return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) |
| } |
| |
| tConn.TCPConn = tcpConn |
| return tConn, nil |
| } |
| } |
| |
| cst := newClientServerTest( |
| t, |
| h1Mode, |
| HandlerFunc(func(w ResponseWriter, r *Request) { |
| io.Copy(io.Discard, r.Body) |
| r.Body.Close() |
| w.WriteHeader(200) |
| }), |
| trFunc, |
| ) |
| defer cst.close() |
| |
| req, err := NewRequest("PUT", cst.ts.URL, r) |
| if err != nil { |
| t.Fatal(err) |
| } |
| req.ContentLength = tc.contentLength |
| req.Header.Set("Content-Type", "application/octet-stream") |
| resp, err := cst.c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer resp.Body.Close() |
| if resp.StatusCode != 200 { |
| t.Fatalf("status code = %d; want 200", resp.StatusCode) |
| } |
| |
| if !tConn.ReadFromCalled && tc.expectedReadFrom { |
| t.Fatalf("did not call ReadFrom") |
| } |
| |
| if tConn.ReadFromCalled && !tc.expectedReadFrom { |
| t.Fatalf("ReadFrom was unexpectedly invoked") |
| } |
| }) |
| } |
| } |
| |
// TestTransportClone checks that Transport.Clone copies every exported
// field. The source Transport sets each exported field to a non-zero value,
// so any field Clone misses shows up as zero on the copy. It also checks
// that a nil TLSNextProto stays nil rather than becoming an empty map.
func TestTransportClone(t *testing.T) {
	tr := &Transport{
		Proxy:                  func(*Request) (*url.URL, error) { panic("") },
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}

	clone := tr.Clone()
	val := reflect.ValueOf(clone).Elem()
	typ := val.Type()
	for i, numFields := 0, typ.NumField(); i < numFields; i++ {
		field := typ.Field(i)
		if token.IsExported(field.Name) && val.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", field.Name)
		}
	}

	if _, found := clone.TLSNextProto["foo"]; !found {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// A nil TLSNextProto on the source must stay nil on the clone.
	if new(Transport).Clone().TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
| |
| func TestIs408(t *testing.T) { |
| tests := []struct { |
| in string |
| want bool |
| }{ |
| {"HTTP/1.0 408", true}, |
| {"HTTP/1.1 408", true}, |
| {"HTTP/1.8 408", true}, |
| {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. |
| {"HTTP/1.1 408 ", true}, |
| {"HTTP/1.1 40", false}, |
| {"http/1.0 408", false}, |
| {"HTTP/1-1 408", false}, |
| } |
| for _, tt := range tests { |
| if got := Export_is408Message([]byte(tt.in)); got != tt.want { |
| t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) |
| } |
| } |
| } |
| |
| func TestTransportIgnores408(t *testing.T) { |
| // Not parallel. Relies on mutating the log package's global Output. |
| defer log.SetOutput(log.Writer()) |
| |
| var logout bytes.Buffer |
| log.SetOutput(&logout) |
| |
| defer afterTest(t) |
| const target = "backend:443" |
| |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| nc, _, err := w.(Hijacker).Hijack() |
| if err != nil { |
| t.Error(err) |
| return |
| } |
| defer nc.Close() |
| nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) |
| nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail |
| })) |
| defer cst.close() |
| req, err := NewRequest("GET", cst.ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| res, err := cst.c.Do(req) |
| if err != nil { |
| t.Fatal(err) |
| } |
| slurp, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(slurp) != "ok" { |
| t.Fatalf("got %q; want ok", slurp) |
| } |
| |
| t0 := time.Now() |
| for i := 0; i < 50; i++ { |
| time.Sleep(time.Duration(i) * 5 * time.Millisecond) |
| if cst.tr.IdleConnKeyCountForTesting() == 0 { |
| if got := logout.String(); got != "" { |
| t.Fatalf("expected no log output; got: %s", got) |
| } |
| return |
| } |
| } |
| t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) |
| } |
| |
| func TestInvalidHeaderResponse(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { |
| conn, buf, _ := w.(Hijacker).Hijack() |
| buf.Write([]byte("HTTP/1.1 200 OK\r\n" + |
| "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + |
| "Content-Type: text/html; charset=utf-8\r\n" + |
| "Content-Length: 0\r\n" + |
| "Foo : bar\r\n\r\n")) |
| buf.Flush() |
| conn.Close() |
| })) |
| defer cst.close() |
| res, err := cst.c.Get(cst.ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer res.Body.Close() |
| if v := res.Header.Get("Foo"); v != "" { |
| t.Errorf(`unexpected "Foo" header: %q`, v) |
| } |
| if v := res.Header.Get("Foo "); v != "bar" { |
| t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") |
| } |
| } |
| |
// bodyCloser is an always-empty request body that remembers whether it has
// been closed: its value is set to true by Close.
type bodyCloser bool

// Close marks the body as closed and always succeeds.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}

// Read never yields data; every call reports io.EOF immediately.
func (b *bodyCloser) Read(p []byte) (int, error) {
	return 0, io.EOF
}
| |
| // Issue 35015: ensure that Transport closes the body on any error |
| // with an invalid request, as promised by Client.Do docs. |
| func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { |
| cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| t.Errorf("Should not have been invoked") |
| })) |
| defer cst.Close() |
| |
| u, _ := url.Parse(cst.URL) |
| |
| tests := []struct { |
| name string |
| req *Request |
| wantErr string |
| }{ |
| { |
| name: "invalid method", |
| req: &Request{ |
| Method: " ", |
| URL: u, |
| }, |
| wantErr: "invalid method", |
| }, |
| { |
| name: "nil URL", |
| req: &Request{ |
| Method: "GET", |
| }, |
| wantErr: "nil Request.URL", |
| }, |
| { |
| name: "invalid header key", |
| req: &Request{ |
| Method: "GET", |
| Header: Header{"💡": {"emoji"}}, |
| URL: u, |
| }, |
| wantErr: "invalid header field name", |
| }, |
| { |
| name: "invalid header value", |
| req: &Request{ |
| Method: "POST", |
| Header: Header{"key": {"\x19"}}, |
| URL: u, |
| }, |
| wantErr: "invalid header field value", |
| }, |
| { |
| name: "non HTTP(s) scheme", |
| req: &Request{ |
| Method: "POST", |
| URL: &url.URL{Scheme: "faux"}, |
| }, |
| wantErr: "unsupported protocol scheme", |
| }, |
| { |
| name: "no Host in URL", |
| req: &Request{ |
| Method: "POST", |
| URL: &url.URL{Scheme: "http"}, |
| }, |
| wantErr: "no Host", |
| }, |
| } |
| |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| var bc bodyCloser |
| req := tt.req |
| req.Body = &bc |
| _, err := DefaultClient.Do(tt.req) |
| if err == nil { |
| t.Fatal("Expected an error") |
| } |
| if !bc { |
| t.Fatal("Expected body to have been closed") |
| } |
| if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) { |
| t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", g, w) |
| } |
| }) |
| } |
| } |
| |
// breakableConn wraps a net.Conn so that writes can be made to fail on
// demand. While the shared brokenState has broken set, Write returns an
// error instead of writing.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag that can be shared by several
// breakableConns.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with "some write error" while the state is broken and
// otherwise forwards to the wrapped conn. It holds the state's lock for the
// whole call, including the forwarded write.
func (w *breakableConn) Write(b []byte) (n int, err error) {
	state := w.brokenState
	state.Lock()
	defer state.Unlock()
	if !state.broken {
		return w.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
| |
| // Issue 34978: don't cache a broken HTTP/2 connection |
| func TestDontCacheBrokenHTTP2Conn(t *testing.T) { |
| cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) |
| defer cst.close() |
| |
| var brokenState brokenState |
| |
| const numReqs = 5 |
| var numDials, gotConns uint32 // atomic |
| |
| cst.tr.Dial = func(netw, addr string) (net.Conn, error) { |
| atomic.AddUint32(&numDials, 1) |
| c, err := net.Dial(netw, addr) |
| if err != nil { |
| t.Errorf("unexpected Dial error: %v", err) |
| return nil, err |
| } |
| return &breakableConn{c, &brokenState}, err |
| } |
| |
| for i := 1; i <= numReqs; i++ { |
| brokenState.Lock() |
| brokenState.broken = false |
| brokenState.Unlock() |
| |
| // doBreak controls whether we break the TCP connection after the TLS |
| // handshake (before the HTTP/2 handshake). We test a few failures |
| // in a row followed by a final success. |
| doBreak := i != numReqs |
| |
| ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ |
| GotConn: func(info httptrace.GotConnInfo) { |
| t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) |
| atomic.AddUint32(&gotConns, 1) |
| }, |
| TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { |
| brokenState.Lock() |
| defer brokenState.Unlock() |
| if doBreak { |
| brokenState.broken = true |
| } |
| }, |
| }) |
| req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| _, err = cst.c.Do(req) |
| if doBreak != (err != nil) { |
| t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) |
| } |
| } |
| if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { |
| t.Errorf("GotConn calls = %v; want %v", got, want) |
| } |
| if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { |
| t.Errorf("Dials = %v; want %v", got, want) |
| } |
| } |
| |
| // Issue 34941 |
| // When the client has too many concurrent requests on a single connection, |
| // http.http2noCachedConnError is reported on multiple requests. There should |
| // only be one decrement regardless of the number of failures. |
| func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { |
| defer afterTest(t) |
| CondSkipHTTP2(t) |
| |
| h := HandlerFunc(func(w ResponseWriter, r *Request) { |
| _, err := w.Write([]byte("foo")) |
| if err != nil { |
| t.Fatalf("Write: %v", err) |
| } |
| }) |
| |
| ts := httptest.NewUnstartedServer(h) |
| ts.EnableHTTP2 = true |
| ts.StartTLS() |
| defer ts.Close() |
| |
| c := ts.Client() |
| tr := c.Transport.(*Transport) |
| tr.MaxConnsPerHost = 1 |
| if err := ExportHttp2ConfigureTransport(tr); err != nil { |
| t.Fatalf("ExportHttp2ConfigureTransport: %v", err) |
| } |
| |
| errCh := make(chan error, 300) |
| doReq := func() { |
| resp, err := c.Get(ts.URL) |
| if err != nil { |
| errCh <- fmt.Errorf("request failed: %v", err) |
| return |
| } |
| defer resp.Body.Close() |
| _, err = io.ReadAll(resp.Body) |
| if err != nil { |
| errCh <- fmt.Errorf("read body failed: %v", err) |
| } |
| } |
| |
| var wg sync.WaitGroup |
| for i := 0; i < 300; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| doReq() |
| }() |
| } |
| wg.Wait() |
| close(errCh) |
| |
| for err := range errCh { |
| t.Errorf("error occurred: %v", err) |
| } |
| } |
| |
| // Issue 36820 |
| // Test that we use the older backward compatible cancellation protocol |
| // when a RoundTripper is registered via RegisterProtocol. |
| func TestAltProtoCancellation(t *testing.T) { |
| defer afterTest(t) |
| tr := &Transport{} |
| c := &Client{ |
| Transport: tr, |
| Timeout: time.Millisecond, |
| } |
| tr.RegisterProtocol("timeout", timeoutProto{}) |
| _, err := c.Get("timeout://bar.com/path") |
| if err == nil { |
| t.Error("request unexpectedly succeeded") |
| } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) { |
| t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr) |
| } |
| } |
| |
// timeoutProtoErr is what timeoutProto returns when its request is canceled.
var timeoutProtoErr = errors.New("canceled as expected")

// timeoutProto is a RoundTripper that never produces a response. It waits
// for the request's Cancel channel, giving up after five seconds.
type timeoutProto struct{}

// RoundTrip returns timeoutProtoErr once req.Cancel is signaled, or a
// "request was not canceled" error if five seconds pass first.
func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
	giveUp := time.After(5 * time.Second)
	select {
	case <-giveUp:
		return nil, errors.New("request was not canceled")
	case <-req.Cancel:
		return nil, timeoutProtoErr
	}
}
| |
// roundTripFunc lets a plain function be used as a RoundTripper.
type roundTripFunc func(r *Request) (*Response, error)

var _ RoundTripper = roundTripFunc(nil)

// RoundTrip calls the underlying function with req.
func (f roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	return f(req)
}
| |
| // Issue 32441: body is not reset after ErrSkipAltProtocol |
| func TestIssue32441(t *testing.T) { |
| defer afterTest(t) |
| ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { |
| if n, _ := io.Copy(io.Discard, r.Body); n == 0 { |
| t.Error("body length is zero") |
| } |
| })) |
| defer ts.Close() |
| c := ts.Client() |
| c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { |
| // Draining body to trigger failure condition on actual request to server. |
| if n, _ := io.Copy(io.Discard, r.Body); n == 0 { |
| t.Error("body length is zero during round trip") |
| } |
| return nil, ErrSkipAltProtocol |
| })) |
| if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { |
| t.Error(err) |
| } |
| } |
| |
// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
// that contain a sign (e.g. "+3"). The grammar is Content-Length = 1*DIGIT
// (RFC 2616, Section 14.13; now RFC 9110, Section 8.6), so a sign is invalid.
func TestTransportRejectsSignInContentLength(t *testing.T) {
	srv := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", "+3")
		w.Write([]byte("abc"))
	}))
	defer srv.Close()

	res, err := srv.Client().Get(srv.URL)
	if !(err != nil && res == nil) {
		t.Fatal("Expected a non-nil error and a nil http.Response")
	}
	const want = `bad Content-Length "+3"`
	if got := err.Error(); !strings.Contains(got, want) {
		t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
	}
}
| |
// dumpConn is a minimal net.Conn. Writes go to the embedded Writer and reads
// come from the embedded Reader. Every other method is a no-op that reports
// nil addresses and nil errors.
type dumpConn struct {
	io.Writer
	io.Reader
}

var _ net.Conn = (*dumpConn)(nil)

func (*dumpConn) Close() error { return nil }

func (*dumpConn) LocalAddr() net.Addr { return nil }

func (*dumpConn) RemoteAddr() net.Addr { return nil }

func (*dumpConn) SetDeadline(time.Time) error { return nil }

func (*dumpConn) SetReadDeadline(time.Time) error { return nil }

func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
| |
// delegateReader is an io.Reader whose real source is supplied later. Until
// a reader has been received on c, Read blocks waiting for one; once one
// arrives, every Read is forwarded to it.
type delegateReader struct {
	c chan io.Reader
	r io.Reader // nil until received from c
}

// Read forwards to the delegate, first receiving it from c if necessary.
// It fails with "delegate closed" if c is closed before a delegate is sent.
func (r *delegateReader) Read(p []byte) (int, error) {
	if r.r != nil {
		return r.r.Read(p)
	}
	next, ok := <-r.c
	r.r = next
	if !ok {
		return 0, errors.New("delegate closed")
	}
	return next.Read(p)
}
| |
// testTransportRace sends req through a fresh Transport whose connections
// are in-memory: request bytes go into a pipe, and the response comes from a
// delegateReader. When it returns, req.Body has been restored to the value it
// had on entry. TestErrorWriteLoopRace uses it to check that the Transport no
// longer touches req once RoundTrip has returned. NOTE(review): that check
// presumably only has teeth under the race detector.
func testTransportRace(req *Request) {
	// Remember the caller's body so it can be put back at the end.
	save := req.Body
	// Bytes the Transport writes land in pw and are read back from pr by the
	// fake server goroutine below.
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	// The response source is supplied later over dr.c.
	dr := &delegateReader{c: make(chan io.Reader)}

	// Every dial hands back the same pipe-backed dumpConn.
	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	// quitReadCh is closed when the fake server goroutine exits.
	quitReadCh := make(chan struct{})
	// Wait for the request before replying with a dummy response:
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Ensure all the body is read; otherwise
			// we'll get a partial dump.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		// Either deliver the canned 204 to a Transport that is reading, or,
		// if the main goroutine is already waiting on quitReadCh (RoundTrip
		// has returned), hand it that signal and close dr.c instead.
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// Ensure delegate is closed so Read doesn't block forever.
			close(dr.c)
		}
	}()

	// The response and error are deliberately ignored; only what happens to
	// req around this call matters.
	t.RoundTrip(req)

	// Ensure the reader returns before we reset req.Body to prevent
	// a data race on req.Body.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
| |
| // Issue 37669 |
| // Test that a cancellation doesn't result in a data race due to the writeLoop |
| // goroutine being left running, if the caller mutates the processed Request |
| // upon completion. |
| func TestErrorWriteLoopRace(t *testing.T) { |
| if testing.Short() { |
| return |
| } |
| t.Parallel() |
| for i := 0; i < 1000; i++ { |
| delay := time.Duration(mrand.Intn(5)) * time.Millisecond |
| ctx, cancel := context.WithTimeout(context.Background(), delay) |
| defer cancel() |
| |
| r := bytes.NewBuffer(make([]byte, 10000)) |
| req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| testTransportRace(req) |
| } |
| } |
| |
// Issue 41600
// Test that a new request which uses the connection of an active request
// cannot cause it to be canceled as well.
//
// Sequence: request 1 completes and pauses inside its PutIdleConn trace hook.
// Request 2 is then sent and canceled while it is at the server. Only after
// that is request 1 allowed to finish. Request 1 must succeed and request 2
// must fail with context.Canceled.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	// Each handler invocation sends the test a channel on reqc and blocks
	// until the test closes it.
	reqc := make(chan chan struct{}, 2)
	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	}))
	defer ts.Close()

	// Limit the Transport to one connection so the second request goes over
	// the same connection as the first.
	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	// Request 1: expected to succeed.
	wg.Add(1)
	putidlec := make(chan chan struct{})
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Signal that the idle conn has been returned to the pool,
				// and wait for the order to proceed.
				ch := make(chan struct{})
				putidlec <- ch
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if err != nil {
			t.Errorf("request 1: got err %v, want nil", err)
		}
	}()

	// Wait for the first request to receive a response and return the
	// connection to the idle pool.
	r1c := <-reqc
	close(r1c)
	idlec := <-putidlec

	// Request 2: canceled while at the server, so it must report
	// context.Canceled.
	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}
	}()

	// Wait for the second request to arrive at the server, and then cancel
	// the request context.
	r2c := <-reqc
	cancel()

	// Give the cancelation a moment to take effect, and then unblock the first request.
	time.Sleep(1 * time.Millisecond)
	close(idlec)

	// Let request 2's handler return, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
| |
| func TestHandlerAbortRacesBodyRead(t *testing.T) { |
| setParallel(t) |
| defer afterTest(t) |
| |
| ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { |
| go io.Copy(io.Discard, req.Body) |
| panic(ErrAbortHandler) |
| })) |
| defer ts.Close() |
| |
| var wg sync.WaitGroup |
| for i := 0; i < 2; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| for j := 0; j < 10; j++ { |
| const reqLen = 6 * 1024 * 1024 |
| req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) |
| req.ContentLength = reqLen |
| resp, _ := ts.Client().Transport.RoundTrip(req) |
| if resp != nil { |
| resp.Body.Close() |
| } |
| } |
| }() |
| } |
| wg.Wait() |
| } |