blob: 3b8d82d5a26d7033a83090d307f78ffc56aeae04 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestAcceptMethods checks that the middleware returned by
// AcceptMethods("GET", "HEAD") invokes the wrapped handler (and the client
// sees 200 OK) for GET and HEAD. It also checks that POST and DELETE skip the
// wrapped handler and get 405 Method Not Allowed.
func TestAcceptMethods(t *testing.T) {
	mw := AcceptMethods("GET", "HEAD")
	// called records whether the wrapped handler ran for the current request.
	// It is reset before each request and inspected after c.Do returns.
	var called bool
	ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
	})))
	defer ts.Close()
	c := ts.Client()
	for _, test := range []struct {
		method string
		want   bool // whether the wrapped handler is expected to run
	}{
		{"GET", true},
		{"HEAD", true},
		{"POST", false},
		{"DELETE", false},
	} {
		called = false
		req, err := http.NewRequest(test.method, ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		// Close the body now rather than via defer: a defer inside the loop
		// would keep every response body open until the test returns.
		res.Body.Close()
		if called != test.want {
			t.Errorf("%s called: got %t, want %t", test.method, called, test.want)
			continue
		}
		// Requests that reached the wrapped handler get the default 200;
		// rejected ones must get 405 from the middleware.
		var wantCode int
		if called {
			wantCode = http.StatusOK
		} else {
			wantCode = http.StatusMethodNotAllowed
		}
		if got := res.StatusCode; got != wantCode {
			t.Errorf("%s code: got %d, want %d", test.method, got, wantCode)
		}
	}
}