net/http: make Server validate Host headers
Fixes #11206 (that we accept invalid bytes)
Fixes #13624 (that we don't require a Host header in HTTP/1.1 per spec)
Change-Id: I4138281d513998789163237e83bb893aeda43336
Reviewed-on: https://go-review.googlesource.com/17892
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/net/http/request.go b/src/net/http/request.go
index 9f74042..01575f3 100644
--- a/src/net/http/request.go
+++ b/src/net/http/request.go
@@ -689,8 +689,9 @@
}
// ReadRequest reads and parses an incoming request from b.
-func ReadRequest(b *bufio.Reader) (req *Request, err error) {
+func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, true) }
+func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) {
tp := newTextprotoReader(b)
req = new(Request)
@@ -757,7 +758,9 @@
if req.Host == "" {
req.Host = req.Header.get("Host")
}
- delete(req.Header, "Host")
+ if deleteHostHeader {
+ delete(req.Header, "Host")
+ }
fixPragmaCacheControl(req.Header)
@@ -1060,3 +1063,59 @@
r.Method == "OPTIONS" ||
r.Method == "TRACE")
}
+
+func validHostHeader(h string) bool {
+ // The latests spec is actually this:
+ //
+ // http://tools.ietf.org/html/rfc7230#section-5.4
+ // Host = uri-host [ ":" port ]
+ //
+ // Where uri-host is:
+ // http://tools.ietf.org/html/rfc3986#section-3.2.2
+ //
+ // But we're going to be much more lenient for now and just
+ // search for any byte that's not a valid byte in any of those
+ // expressions.
+ for i := 0; i < len(h); i++ {
+ if !validHostByte[h[i]] {
+ return false
+ }
+ }
+ return true
+}
+
+// See the validHostHeader comment.
+var validHostByte = [256]bool{
+ '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
+ '8': true, '9': true,
+
+ 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true,
+ 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true,
+ 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true,
+ 'y': true, 'z': true,
+
+ 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
+ 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true,
+ 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true,
+ 'Y': true, 'Z': true,
+
+ '!': true, // sub-delims
+ '$': true, // sub-delims
+ '%': true, // pct-encoded (and used in IPv6 zones)
+ '&': true, // sub-delims
+ '(': true, // sub-delims
+ ')': true, // sub-delims
+ '*': true, // sub-delims
+ '+': true, // sub-delims
+ ',': true, // sub-delims
+ '-': true, // unreserved
+ '.': true, // unreserved
+ ':': true, // IPv6address + Host expression's optional port
+ ';': true, // sub-delims
+ '=': true, // sub-delims
+ '[': true,
+ '\'': true, // sub-delims
+ ']': true,
+ '_': true, // unreserved
+ '~': true, // unreserved
+}
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index 3e84f2e..31ba06a 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -2201,7 +2201,7 @@
// buffered before chunk headers are added, not after chunk headers.
func TestServerBufferedChunking(t *testing.T) {
conn := new(testConn)
- conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+ conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
conn.closec = make(chan bool, 1)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -2934,9 +2934,9 @@
"GET / HTTP/1.0",
"GET /header HTTP/1.0",
"GET /more HTTP/1.0",
- "GET / HTTP/1.1",
- "GET /header HTTP/1.1",
- "GET /more HTTP/1.1",
+ "GET / HTTP/1.1\nHost: foo",
+ "GET /header HTTP/1.1\nHost: foo",
+ "GET /more HTTP/1.1\nHost: foo",
} {
got := ht.rawResponse(req)
wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
@@ -2957,7 +2957,7 @@
w.Header().Set("Content-Type", "foo/bar")
w.WriteHeader(204)
}))
- got := ht.rawResponse("GET / HTTP/1.1")
+ got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
if !strings.Contains(got, "Content-Type: foo/bar") {
t.Errorf("Response = %q; want Content-Type: foo/bar", got)
}
@@ -3628,6 +3628,54 @@
}
}
+// Test that we validate the Host header.
+func TestServerValidatesHostHeader(t *testing.T) {
+ tests := []struct {
+ proto string
+ host string
+ want int
+ }{
+ {"HTTP/1.1", "", 400},
+ {"HTTP/1.1", "Host: \r\n", 200},
+ {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+ {"HTTP/1.1", "Host: foo.com\r\n", 200},
+ {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
+ {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
+ {"HTTP/1.1", "Host: ::1\r\n", 200},
+ {"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it
+ {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
+ {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
+ {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+ {"HTTP/1.1", "Host: \x06\r\n", 400},
+ {"HTTP/1.1", "Host: \xff\r\n", 400},
+ {"HTTP/1.1", "Host: {\r\n", 400},
+ {"HTTP/1.1", "Host: }\r\n", 400},
+ {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
+
+ // HTTP/1.0 can lack a host header, but if present
+ // must play by the rules too:
+ {"HTTP/1.0", "", 200},
+ {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
+ {"HTTP/1.0", "Host: \xff\r\n", 400},
+ }
+ for _, tt := range tests {
+ conn := &testConn{closec: make(chan bool)}
+ io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n")
+
+ ln := &oneConnListener{conn}
+ go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
+ <-conn.closec
+ res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+ if err != nil {
+ t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
+ }
+ }
+}
+
func BenchmarkClientServer(b *testing.B) {
b.ReportAllocs()
b.StopTimer()
diff --git a/src/net/http/server.go b/src/net/http/server.go
index cd5f9cf..a00085c 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -686,7 +686,7 @@
peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
c.bufr.Discard(numLeadingCRorLF(peek))
}
- req, err := ReadRequest(c.bufr)
+ req, err := readRequest(c.bufr, false)
c.mu.Unlock()
if err != nil {
if c.r.hitReadLimit() {
@@ -697,6 +697,18 @@
c.lastMethod = req.Method
c.r.setInfiniteReadLimit()
+ hosts, haveHost := req.Header["Host"]
+ if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) {
+ return nil, badRequestError("missing required Host header")
+ }
+ if len(hosts) > 1 {
+ return nil, badRequestError("too many Host headers")
+ }
+ if len(hosts) == 1 && !validHostHeader(hosts[0]) {
+ return nil, badRequestError("malformed Host header")
+ }
+ delete(req.Header, "Host")
+
req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState
if body, ok := req.Body.(*body); ok {
@@ -1334,6 +1346,13 @@
}
}
+// badRequestError is a literal string (used by in the server in HTML,
+// unescaped) to tell the user why their request was bad. It should
+// be plain text without user info or other embeddded errors.
+type badRequestError string
+
+func (e badRequestError) Error() string { return "Bad Request: " + string(e) }
+
// Serve a new connection.
func (c *conn) serve() {
c.remoteAddr = c.rwc.RemoteAddr().String()
@@ -1399,7 +1418,11 @@
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return // don't reply
}
- io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request")
+ var publicErr string
+ if v, ok := err.(badRequestError); ok {
+ publicErr = ": " + string(v)
+ }
+ io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr)
return
}