blob: d9206a30465d233f777d63a8c93b310951d868e3 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestTimeout exercises the Timeout middleware against a handler that reports
// whether its request context has already been cancelled. Three routes are
// served from one test server: one wrapped with a 5-second timeout, one
// wrapped with a zero timeout, and one with no middleware at all.
func TestTimeout(t *testing.T) {
	// The handler does a non-blocking check of the request context: if the
	// context is already done it answers 500, otherwise it answers 200 with
	// a greeting.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-r.Context().Done():
			http.Error(w, "bad", http.StatusInternalServerError)
			return
		default:
		}
		fmt.Fprintln(w, "Hello!")
	})

	mux := http.NewServeMux()
	mux.Handle("/A", Timeout(5*time.Second)(handler))
	// A zero timeout is expected to hand the handler a context that is
	// already done, which the handler turns into a 500.
	mux.Handle("/B", Timeout(0)(handler))
	mux.Handle("/C", handler)

	ts := httptest.NewServer(mux)
	defer ts.Close()
	client := ts.Client()

	tests := []struct {
		label      string
		url        string
		wantStatus int
	}{
		{
			label:      "normal timed request",
			url:        fmt.Sprintf("%s/A", ts.URL),
			wantStatus: http.StatusOK,
		},
		{
			label:      "timed-out request",
			url:        fmt.Sprintf("%s/B", ts.URL),
			wantStatus: http.StatusInternalServerError,
		},
		{
			label:      "request with no timeout",
			url:        fmt.Sprintf("%s/C", ts.URL),
			wantStatus: http.StatusOK,
		},
	}
	for _, test := range tests {
		t.Run(test.label, func(t *testing.T) {
			resp, err := client.Get(test.url)
			if err != nil {
				// resp is nil when err is non-nil, so stop this subtest
				// here rather than dereferencing it below.
				t.Fatalf("GET %s got error %v, want nil", test.url, err)
			}
			defer resp.Body.Close()

			if resp.StatusCode != test.wantStatus {
				t.Errorf("GET %s returned status %d, want %d", test.url, resp.StatusCode, test.wantStatus)
			}
		})
	}
}