blob: a29e3ae16dff6b432decb479ac9a23ee9f74e8ca [file] [log] [blame] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http_test
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// httptestNewRequest works around https://go.dev/issue/73151.
func httptestNewRequest(method, target string) *http.Request {
	// Same as httptest.NewRequest with a nil body, except that any Scheme and
	// Host left in req.URL are cleared, so req.URL carries only the path and
	// query. req.Host itself is left untouched.
	// This works around https://go.dev/issue/73151.
	r := httptest.NewRequest(method, target, nil)
	r.URL.Host, r.URL.Scheme = "", ""
	return r
}
// okHandler is the inner handler wrapped by CrossOriginProtection in these
// tests. It answers every request with 200 OK and writes no body.
var okHandler http.HandlerFunc = func(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
}
// TestCrossOriginProtectionSecFetchSite checks the decision made for a request
// to https://example.com/ from its method, its Sec-Fetch-Site header and,
// when Sec-Fetch-Site is absent, its Origin header:
//   - For POST, "same-origin" and "none" pass while "cross-site" and "same-site" are rejected.
//   - Without Sec-Fetch-Site, a missing or matching Origin passes, while a mismatched or "null" Origin is rejected.
//   - GET, HEAD and OPTIONS pass even when marked cross-site; PUT does not.
func TestCrossOriginProtectionSecFetchSite(t *testing.T) {
	h := http.NewCrossOriginProtection().Handler(okHandler)

	cases := []struct {
		name         string
		method       string
		secFetchSite string
		origin       string
		want         int
	}{
		{"same-origin allowed", "POST", "same-origin", "", http.StatusOK},
		{"none allowed", "POST", "none", "", http.StatusOK},
		{"cross-site blocked", "POST", "cross-site", "", http.StatusForbidden},
		{"same-site blocked", "POST", "same-site", "", http.StatusForbidden},
		{"no header with no origin", "POST", "", "", http.StatusOK},
		{"no header with matching origin", "POST", "", "https://example.com", http.StatusOK},
		{"no header with mismatched origin", "POST", "", "https://attacker.example", http.StatusForbidden},
		{"no header with null origin", "POST", "", "null", http.StatusForbidden},
		{"GET allowed", "GET", "cross-site", "", http.StatusOK},
		{"HEAD allowed", "HEAD", "cross-site", "", http.StatusOK},
		{"OPTIONS allowed", "OPTIONS", "cross-site", "", http.StatusOK},
		{"PUT blocked", "PUT", "cross-site", "", http.StatusForbidden},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			req := httptestNewRequest(c.method, "https://example.com/")
			// An empty table value means the header is not sent at all.
			for _, hdr := range [...]struct{ key, val string }{
				{"Sec-Fetch-Site", c.secFetchSite},
				{"Origin", c.origin},
			} {
				if hdr.val != "" {
					req.Header.Set(hdr.key, hdr.val)
				}
			}

			rec := httptest.NewRecorder()
			h.ServeHTTP(rec, req)
			if got := rec.Code; got != c.want {
				t.Errorf("got status %d, want %d", got, c.want)
			}
		})
	}
}
// TestCrossOriginProtectionTrustedOriginBypass checks that a POST whose Origin
// was registered with AddTrustedOrigin is allowed, with or without
// Sec-Fetch-Site: cross-site. A POST from an unregistered Origin is rejected
// in both cases.
func TestCrossOriginProtectionTrustedOriginBypass(t *testing.T) {
	p := http.NewCrossOriginProtection()
	if err := p.AddTrustedOrigin("https://trusted.example"); err != nil {
		t.Fatalf("AddTrustedOrigin: %v", err)
	}
	h := p.Handler(okHandler)

	cases := []struct {
		name         string
		origin       string
		secFetchSite string
		want         int
	}{
		{"trusted origin without sec-fetch-site", "https://trusted.example", "", http.StatusOK},
		{"trusted origin with cross-site", "https://trusted.example", "cross-site", http.StatusOK},
		{"untrusted origin without sec-fetch-site", "https://attacker.example", "", http.StatusForbidden},
		{"untrusted origin with cross-site", "https://attacker.example", "cross-site", http.StatusForbidden},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			req := httptestNewRequest("POST", "https://example.com/")
			req.Header.Set("Origin", c.origin)
			if v := c.secFetchSite; v != "" {
				req.Header.Set("Sec-Fetch-Site", v)
			}

			rec := httptest.NewRecorder()
			h.ServeHTTP(rec, req)
			if rec.Code != c.want {
				t.Errorf("got status %d, want %d", rec.Code, c.want)
			}
		})
	}
}
// TestCrossOriginProtectionPatternBypass checks AddInsecureBypassPattern.
// Every request is a POST with Origin https://attacker.example, and it is
// allowed only when its path matches a registered pattern:
//   - Wildcards match a single segment, and trailing slashes must match exactly.
//   - Method-prefixed patterns only bypass requests of that method.
//   - Paths containing ".." segments, or lacking a subtree pattern's trailing slash, are rejected.
func TestCrossOriginProtectionPatternBypass(t *testing.T) {
	p := http.NewCrossOriginProtection()
	for _, pattern := range []string{
		"/bypass/",
		"/only/{foo}",
		"/no-trailing",
		"/yes-trailing/",
		"PUT /put-only/",
		"GET /get-only/",
		"POST /post-only/",
	} {
		p.AddInsecureBypassPattern(pattern)
	}
	h := p.Handler(okHandler)

	cases := []struct {
		name         string
		path         string
		secFetchSite string
		want         int
	}{
		{"bypass path without sec-fetch-site", "/bypass/", "", http.StatusOK},
		{"bypass path with cross-site", "/bypass/", "cross-site", http.StatusOK},
		{"non-bypass path without sec-fetch-site", "/api/", "", http.StatusForbidden},
		{"non-bypass path with cross-site", "/api/", "cross-site", http.StatusForbidden},
		{"redirect to bypass path without ..", "/foo/../bypass/bar", "", http.StatusForbidden},
		{"redirect to bypass path with trailing slash", "/bypass", "", http.StatusForbidden},
		{"redirect to non-bypass path with ..", "/foo/../api/bar", "", http.StatusForbidden},
		{"redirect to non-bypass path with trailing slash", "/api", "", http.StatusForbidden},
		{"wildcard bypass", "/only/123", "", http.StatusOK},
		{"non-wildcard", "/only/123/foo", "", http.StatusForbidden},
		// https://go.dev/issue/75054
		{"no trailing slash exact match", "/no-trailing", "", http.StatusOK},
		{"no trailing slash with slash", "/no-trailing/", "", http.StatusForbidden},
		{"yes trailing slash exact match", "/yes-trailing/", "", http.StatusOK},
		{"yes trailing slash without slash", "/yes-trailing", "", http.StatusForbidden},
		{"method-specific hit", "/post-only/", "", http.StatusOK},
		{"method-specific miss (PUT)", "/put-only/", "", http.StatusForbidden},
		{"method-specific miss (GET)", "/get-only/", "", http.StatusForbidden},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			req := httptestNewRequest("POST", "https://example.com"+c.path)
			req.Header.Set("Origin", "https://attacker.example")
			if c.secFetchSite != "" {
				req.Header.Set("Sec-Fetch-Site", c.secFetchSite)
			}

			rec := httptest.NewRecorder()
			h.ServeHTTP(rec, req)
			if got, want := rec.Code, c.want; got != want {
				t.Errorf("got status %d, want %d", got, want)
			}
		})
	}
}
// TestCrossOriginProtectionSetDenyHandler checks SetDenyHandler in four steps:
//  1. With the default deny handler, a cross-site POST gets 403.
//  2. With a custom deny handler installed, that same request gets the custom 418 response and body.
//  3. An allowed GET still reaches okHandler.
//  4. After SetDenyHandler(nil), a cross-site POST gets 403 again.
// The protection is reconfigured after Handler has already been called.
func TestCrossOriginProtectionSetDenyHandler(t *testing.T) {
	p := http.NewCrossOriginProtection()
	h := p.Handler(okHandler)

	// serve sends r through h and reports an error if the status is not want.
	serve := func(r *http.Request, want int) *httptest.ResponseRecorder {
		t.Helper()
		rec := httptest.NewRecorder()
		h.ServeHTTP(rec, r)
		if rec.Code != want {
			t.Errorf("got status %d, want %d", rec.Code, want)
		}
		return rec
	}
	crossSitePOST := func() *http.Request {
		r := httptestNewRequest("POST", "https://example.com/")
		r.Header.Set("Sec-Fetch-Site", "cross-site")
		return r
	}

	// Default deny handler.
	blocked := crossSitePOST()
	serve(blocked, http.StatusForbidden)

	// Custom deny handler takes over the response for rejected requests.
	p.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTeapot)
		io.WriteString(w, "custom error")
	}))
	rec := serve(blocked, http.StatusTeapot)
	if body := rec.Body.String(); !strings.Contains(body, "custom error") {
		t.Errorf("expected custom error message, got: %q", body)
	}

	// Allowed requests still reach okHandler.
	serve(httptestNewRequest("GET", "https://example.com/"), http.StatusOK)

	// A nil deny handler brings back the 403 response.
	p.SetDenyHandler(nil)
	serve(crossSitePOST(), http.StatusForbidden)
}
// TestCrossOriginProtectionAddTrustedOriginErrors checks which strings
// AddTrustedOrigin accepts. A bare scheme://host[:port] origin (http or
// https) is accepted. The following are rejected:
//   - values missing a scheme or host;
//   - values with any path (even a lone "/"), query or fragment;
//   - an unparsable URL, "" and "null".
// All cases share one CrossOriginProtection.
func TestCrossOriginProtectionAddTrustedOriginErrors(t *testing.T) {
	p := http.NewCrossOriginProtection()
	for _, c := range []struct {
		name    string
		origin  string
		wantErr bool
	}{
		{"valid origin", "https://example.com", false},
		{"valid origin with port", "https://example.com:8080", false},
		{"http origin", "http://example.com", false},
		{"missing scheme", "example.com", true},
		{"missing host", "https://", true},
		{"trailing slash", "https://example.com/", true},
		{"with path", "https://example.com/path", true},
		{"with query", "https://example.com?query=value", true},
		{"with fragment", "https://example.com#fragment", true},
		{"invalid url", "https://ex ample.com", true},
		{"empty string", "", true},
		{"null", "null", true},
	} {
		t.Run(c.name, func(t *testing.T) {
			err := p.AddTrustedOrigin(c.origin)
			if gotErr := err != nil; gotErr != c.wantErr {
				t.Errorf("AddTrustedOrigin(%q) error = %v, wantErr %v", c.origin, err, c.wantErr)
			}
		})
	}
}
// TestCrossOriginProtectionAddingBypassesConcurrently reconfigures a
// CrossOriginProtection while a second goroutine serves requests through its
// handler. It calls AddTrustedOrigin and AddInsecureBypassPattern during this
// time. The in-flight responses are intentionally not checked; presumably the
// point is to let `go test -race` observe concurrent configuration and
// serving. Once the goroutine finishes, the request from the newly trusted
// origin must be allowed.
func TestCrossOriginProtectionAddingBypassesConcurrently(t *testing.T) {
	protection := http.NewCrossOriginProtection()
	handler := protection.Handler(okHandler)
	req := httptestNewRequest("POST", "https://example.com/")
	req.Header.Set("Origin", "https://concurrent.example")
	req.Header.Set("Sec-Fetch-Site", "cross-site")
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	if w.Code != http.StatusForbidden {
		t.Errorf("got status %d, want %d", w.Code, http.StatusForbidden)
	}
	start := make(chan struct{})
	done := make(chan struct{})
	go func() {
		close(start)
		defer close(done)
		for range 10 {
			// Responses are deliberately ignored. Only the test goroutine
			// reports failures.
			handler.ServeHTTP(httptest.NewRecorder(), req)
		}
	}()
	// Add bypasses while the requests are in flight.
	<-start
	if err := protection.AddTrustedOrigin("https://concurrent.example"); err != nil {
		// Errorf, not Fatalf, so that we still wait for the goroutine below.
		t.Errorf("AddTrustedOrigin: %v", err)
	}
	protection.AddInsecureBypassPattern("/foo/")
	<-done
	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("After concurrent bypass addition, got status %d, want %d", w.Code, http.StatusOK)
	}
}
// TestCrossOriginProtectionServer runs the protected handler behind a real
// httptest.Server and sends requests over the network. As a result, the
// "origin matches host" case compares Origin against the server's actual
// address (ts.URL). It also covers Sec-Fetch-Site, trusted origins and
// bypass patterns end to end.
func TestCrossOriginProtectionServer(t *testing.T) {
	protection := http.NewCrossOriginProtection()
	if err := protection.AddTrustedOrigin("https://trusted.example"); err != nil {
		t.Fatalf("AddTrustedOrigin: %v", err)
	}
	protection.AddInsecureBypassPattern("/bypass/")
	handler := protection.Handler(okHandler)
	ts := httptest.NewServer(handler)
	defer ts.Close()
	// ts.Client is configured for this server, and its idle connections are
	// closed by ts.Close. Use one client for all subtests rather than a fresh
	// &http.Client{} per subtest.
	client := ts.Client()
	tests := []struct {
		name           string
		method         string
		url            string
		origin         string
		secFetchSite   string
		expectedStatus int
	}{
		{"cross-site", "POST", ts.URL, "https://attacker.example", "cross-site", http.StatusForbidden},
		{"same-origin", "POST", ts.URL, "", "same-origin", http.StatusOK},
		{"origin matches host", "POST", ts.URL, ts.URL, "", http.StatusOK},
		{"trusted origin", "POST", ts.URL, "https://trusted.example", "", http.StatusOK},
		{"untrusted origin", "POST", ts.URL, "https://attacker.example", "", http.StatusForbidden},
		{"bypass path", "POST", ts.URL + "/bypass/", "https://attacker.example", "", http.StatusOK},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req, err := http.NewRequest(tc.method, tc.url, nil)
			if err != nil {
				t.Fatalf("NewRequest: %v", err)
			}
			if tc.origin != "" {
				req.Header.Set("Origin", tc.origin)
			}
			if tc.secFetchSite != "" {
				req.Header.Set("Sec-Fetch-Site", tc.secFetchSite)
			}
			resp, err := client.Do(req)
			if err != nil {
				t.Fatalf("Do: %v", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != tc.expectedStatus {
				t.Errorf("got status %d, want %d", resp.StatusCode, tc.expectedStatus)
			}
		})
	}
}