blob: 6e0c8731660a8c157b4b98bd0332aebecb8cd379 [file] [log] [blame]
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
var testHosts = map[string]string{
"": "foo.example.test",
"dev.example.test": "foo-dev.example.test",
"staging.example.test": "foo-staging.example.test",
}
func TestRedirect(t *testing.T) {
tests := []struct {
desc string
target string
hosts map[string]string
want string
wantCode int
}{
{
desc: "basic redirect",
target: "/",
hosts: testHosts,
want: "https://foo.example.test/",
wantCode: http.StatusFound,
},
{
desc: "redirect keeps query and path",
target: "/github.com/golang/glog?tab=overview",
hosts: testHosts,
want: "https://foo.example.test/github.com/golang/glog?tab=overview",
wantCode: http.StatusFound,
},
{
desc: "redirects to the correct host",
target: "https://dev.example.test/",
hosts: testHosts,
want: "https://foo-dev.example.test/",
wantCode: http.StatusFound,
},
{
desc: "renders 404 if hosts are missing",
target: "https://dev.example.test/",
hosts: nil,
wantCode: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
req := httptest.NewRequest("GET", tt.target, nil)
w := httptest.NewRecorder()
redirectHosts(tt.hosts).ServeHTTP(w, req)
resp := w.Result()
if resp.StatusCode != tt.wantCode {
t.Errorf("w.Result().StatusCode = %v, wanted %v", resp.StatusCode, tt.wantCode)
}
l, err := resp.Location()
if resp.StatusCode == http.StatusFound && (l == nil || l.String() != tt.want || err != nil) {
t.Errorf("resp.Location() = %v, %v, wanted %v, no error", l, err, tt.want)
}
})
}
}