| // Copyright 2012 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package httptest |
| |
| import ( |
| "bufio" |
| "io" |
| "net" |
| "net/http" |
| "sync" |
| "testing" |
| ) |
| |
| type newServerFunc func(http.Handler) *Server |
| |
| var newServers = map[string]newServerFunc{ |
| "NewServer": NewServer, |
| "NewTLSServer": NewTLSServer, |
| |
| // The manual variants of newServer create a Server manually by only filling |
| // in the exported fields of Server. |
| "NewServerManual": func(h http.Handler) *Server { |
| ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} |
| ts.Start() |
| return ts |
| }, |
| "NewTLSServerManual": func(h http.Handler) *Server { |
| ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} |
| ts.StartTLS() |
| return ts |
| }, |
| } |
| |
| func TestServer(t *testing.T) { |
| for _, name := range []string{"NewServer", "NewServerManual"} { |
| t.Run(name, func(t *testing.T) { |
| newServer := newServers[name] |
| t.Run("Server", func(t *testing.T) { testServer(t, newServer) }) |
| t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) }) |
| t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) }) |
| t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) }) |
| t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) }) |
| }) |
| } |
| for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} { |
| t.Run(name, func(t *testing.T) { |
| newServer := newServers[name] |
| t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) }) |
| t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) }) |
| }) |
| } |
| } |
| |
| func testServer(t *testing.T, newServer newServerFunc) { |
| ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Write([]byte("hello")) |
| })) |
| defer ts.Close() |
| res, err := http.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| got, err := io.ReadAll(res.Body) |
| res.Body.Close() |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(got) != "hello" { |
| t.Errorf("got %q, want hello", string(got)) |
| } |
| } |
| |
| // Issue 12781 |
| func testGetAfterClose(t *testing.T, newServer newServerFunc) { |
| ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Write([]byte("hello")) |
| })) |
| |
| res, err := http.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| got, err := io.ReadAll(res.Body) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(got) != "hello" { |
| t.Fatalf("got %q, want hello", string(got)) |
| } |
| |
| ts.Close() |
| |
| res, err = http.Get(ts.URL) |
| if err == nil { |
| body, _ := io.ReadAll(res.Body) |
| t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body) |
| } |
| } |
| |
| func testServerCloseBlocking(t *testing.T, newServer newServerFunc) { |
| ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Write([]byte("hello")) |
| })) |
| dial := func() net.Conn { |
| c, err := net.Dial("tcp", ts.Listener.Addr().String()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| return c |
| } |
| |
| // Keep one connection in StateNew (connected, but not sending anything) |
| cnew := dial() |
| defer cnew.Close() |
| |
| // Keep one connection in StateIdle (idle after a request) |
| cidle := dial() |
| defer cidle.Close() |
| cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) |
| _, err := http.ReadResponse(bufio.NewReader(cidle), nil) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| ts.Close() // test we don't hang here forever. |
| } |
| |
| // Issue 14290 |
| func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) { |
| var s *Server |
| s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| s.CloseClientConnections() |
| })) |
| defer s.Close() |
| res, err := http.Get(s.URL) |
| if err == nil { |
| res.Body.Close() |
| t.Fatalf("Unexpected response: %#v", res) |
| } |
| } |
| |
| // Tests that the Server.Client method works and returns an http.Client that can hit |
| // NewTLSServer without cert warnings. |
| func testServerClient(t *testing.T, newTLSServer newServerFunc) { |
| ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Write([]byte("hello")) |
| })) |
| defer ts.Close() |
| client := ts.Client() |
| res, err := client.Get(ts.URL) |
| if err != nil { |
| t.Fatal(err) |
| } |
| got, err := io.ReadAll(res.Body) |
| res.Body.Close() |
| if err != nil { |
| t.Fatal(err) |
| } |
| if string(got) != "hello" { |
| t.Errorf("got %q, want hello", string(got)) |
| } |
| } |
| |
| // Tests that the Server.Client.Transport interface is implemented |
| // by a *http.Transport. |
| func testServerClientTransportType(t *testing.T, newServer newServerFunc) { |
| ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| })) |
| defer ts.Close() |
| client := ts.Client() |
| if _, ok := client.Transport.(*http.Transport); !ok { |
| t.Errorf("got %T, want *http.Transport", client.Transport) |
| } |
| } |
| |
| // Tests that the TLS Server.Client.Transport interface is implemented |
| // by a *http.Transport. |
| func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) { |
| ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| })) |
| defer ts.Close() |
| client := ts.Client() |
| if _, ok := client.Transport.(*http.Transport); !ok { |
| t.Errorf("got %T, want *http.Transport", client.Transport) |
| } |
| } |
| |
// onlyCloseListener embeds a net.Listener and overrides its Close method
// with one that does nothing.
type onlyCloseListener struct{ net.Listener }

// Close ignores the embedded Listener and always returns a nil error.
func (onlyCloseListener) Close() error {
	return nil
}
| |
| // Issue 19729: panic in Server.Close for values created directly |
| // without a constructor (so the unexported client field is nil). |
| func TestServerZeroValueClose(t *testing.T) { |
| ts := &Server{ |
| Listener: onlyCloseListener{}, |
| Config: &http.Server{}, |
| } |
| |
| ts.Close() // tests that it doesn't panic |
| } |
| |
| // Issue 51799: test hijacking a connection and then closing it |
| // concurrently with closing the server. |
| func TestCloseHijackedConnection(t *testing.T) { |
| hijacked := make(chan net.Conn) |
| ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| defer close(hijacked) |
| hj, ok := w.(http.Hijacker) |
| if !ok { |
| t.Fatal("failed to hijack") |
| } |
| c, _, err := hj.Hijack() |
| if err != nil { |
| t.Fatal(err) |
| } |
| hijacked <- c |
| })) |
| |
| var wg sync.WaitGroup |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| req, err := http.NewRequest("GET", ts.URL, nil) |
| if err != nil { |
| t.Log(err) |
| } |
| // Use a client not associated with the Server. |
| var c http.Client |
| resp, err := c.Do(req) |
| if err != nil { |
| t.Log(err) |
| return |
| } |
| resp.Body.Close() |
| }() |
| |
| wg.Add(1) |
| conn := <-hijacked |
| go func(conn net.Conn) { |
| defer wg.Done() |
| // Close the connection and then inform the Server that |
| // we closed it. |
| conn.Close() |
| ts.Config.ConnState(conn, http.StateClosed) |
| }(conn) |
| |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| ts.Close() |
| }() |
| wg.Wait() |
| } |
| |
| func TestTLSServerWithHTTP2(t *testing.T) { |
| modes := []struct { |
| name string |
| wantProto string |
| }{ |
| {"http1", "HTTP/1.1"}, |
| {"http2", "HTTP/2.0"}, |
| } |
| |
| for _, tt := range modes { |
| t.Run(tt.name, func(t *testing.T) { |
| cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("X-Proto", r.Proto) |
| })) |
| |
| switch tt.name { |
| case "http2": |
| cst.EnableHTTP2 = true |
| cst.StartTLS() |
| default: |
| cst.Start() |
| } |
| |
| defer cst.Close() |
| |
| res, err := cst.Client().Get(cst.URL) |
| if err != nil { |
| t.Fatalf("Failed to make request: %v", err) |
| } |
| if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w { |
| t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w) |
| } |
| }) |
| } |
| } |