net/http: make Server validate Host headers

Fixes #11206 (that we accept invalid bytes)
Fixes #13624 (that we don't require a Host header in HTTP/1.1 per spec)

Change-Id: I4138281d513998789163237e83bb893aeda43336
Reviewed-on: https://go-review.googlesource.com/17892
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/src/net/http/request.go b/src/net/http/request.go
index 9f74042..01575f3 100644
--- a/src/net/http/request.go
+++ b/src/net/http/request.go
@@ -689,8 +689,9 @@
 }
 
 // ReadRequest reads and parses an incoming request from b.
-func ReadRequest(b *bufio.Reader) (req *Request, err error) {
+func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, true) }
 
+func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) {
 	tp := newTextprotoReader(b)
 	req = new(Request)
 
@@ -757,7 +758,9 @@
 	if req.Host == "" {
 		req.Host = req.Header.get("Host")
 	}
-	delete(req.Header, "Host")
+	if deleteHostHeader {
+		delete(req.Header, "Host")
+	}
 
 	fixPragmaCacheControl(req.Header)
 
@@ -1060,3 +1063,59 @@
 			r.Method == "OPTIONS" ||
 			r.Method == "TRACE")
 }
+
+// validHostHeader reports whether every byte of h is one that may
+// appear in a Host header. The latest spec is actually this:
+//
+// http://tools.ietf.org/html/rfc7230#section-5.4
+//     Host = uri-host [ ":" port ]
+//
+// Where uri-host is:
+//     http://tools.ietf.org/html/rfc3986#section-3.2.2
+//
+// But we're going to be much more lenient for now and just reject
+// any byte that's not a valid byte in any of those expressions.
+func validHostHeader(h string) bool {
+	for _, c := range []byte(h) {
+		if !validHostByte[c] {
+			return false
+		}
+	}
+	return true
+}
+
+// validHostByte marks the bytes validHostHeader accepts; see its comment.
+var validHostByte = [256]bool{
+	'0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
+	'8': true, '9': true,
+
+	'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true,
+	'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true,
+	'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true,
+	'y': true, 'z': true,
+
+	'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
+	'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true,
+	'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true,
+	'Y': true, 'Z': true,
+
+	'!':  true, // sub-delims
+	'$':  true, // sub-delims
+	'%':  true, // pct-encoded (and used in IPv6 zones)
+	'&':  true, // sub-delims
+	'(':  true, // sub-delims
+	')':  true, // sub-delims
+	'*':  true, // sub-delims
+	'+':  true, // sub-delims
+	',':  true, // sub-delims
+	'-':  true, // unreserved
+	'.':  true, // unreserved
+	':':  true, // IPv6address + Host expression's optional port
+	';':  true, // sub-delims
+	'=':  true, // sub-delims
+	'[':  true, // opens a bracketed IPv6 literal, e.g. [::1]
+	'\'': true, // sub-delims
+	']':  true, // closes a bracketed IPv6 literal
+	'_':  true, // unreserved
+	'~':  true, // unreserved
+}
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index 3e84f2e..31ba06a 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -2201,7 +2201,7 @@
 // buffered before chunk headers are added, not after chunk headers.
 func TestServerBufferedChunking(t *testing.T) {
 	conn := new(testConn)
-	conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+	conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
 	conn.closec = make(chan bool, 1)
 	ls := &oneConnListener{conn}
 	go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -2934,9 +2934,9 @@
 			"GET / HTTP/1.0",
 			"GET /header HTTP/1.0",
 			"GET /more HTTP/1.0",
-			"GET / HTTP/1.1",
-			"GET /header HTTP/1.1",
-			"GET /more HTTP/1.1",
+			"GET / HTTP/1.1\nHost: foo",
+			"GET /header HTTP/1.1\nHost: foo",
+			"GET /more HTTP/1.1\nHost: foo",
 		} {
 			got := ht.rawResponse(req)
 			wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
@@ -2957,7 +2957,7 @@
 		w.Header().Set("Content-Type", "foo/bar")
 		w.WriteHeader(204)
 	}))
-	got := ht.rawResponse("GET / HTTP/1.1")
+	got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
 	if !strings.Contains(got, "Content-Type: foo/bar") {
 		t.Errorf("Response = %q; want Content-Type: foo/bar", got)
 	}
@@ -3628,6 +3628,54 @@
 	}
 }
 
+// Test that the Server answers 400 for missing, repeated, or malformed Host headers.
+func TestServerValidatesHostHeader(t *testing.T) {
+	tests := []struct {
+		proto string
+		host  string
+		want  int
+	}{
+		{"HTTP/1.1", "", 400}, // Host is required in HTTP/1.1
+		{"HTTP/1.1", "Host: \r\n", 200},
+		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+		{"HTTP/1.1", "Host: foo.com\r\n", 200},
+		{"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
+		{"HTTP/1.1", "Host: foo.com:80\r\n", 200},
+		{"HTTP/1.1", "Host: ::1\r\n", 200},
+		{"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it
+		{"HTTP/1.1", "Host: [::1]:80\r\n", 200},
+		{"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
+		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+		{"HTTP/1.1", "Host: \x06\r\n", 400},
+		{"HTTP/1.1", "Host: \xff\r\n", 400},
+		{"HTTP/1.1", "Host: {\r\n", 400},
+		{"HTTP/1.1", "Host: }\r\n", 400},
+		{"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400}, // more than one Host header
+
+		// HTTP/1.0 can lack a host header, but if present
+		// must play by the rules too:
+		{"HTTP/1.0", "", 200},
+		{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
+		{"HTTP/1.0", "Host: \xff\r\n", 400},
+	}
+	for _, tt := range tests {
+		conn := &testConn{closec: make(chan bool)}
+		io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n")
+
+		ln := &oneConnListener{conn}
+		go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
+		<-conn.closec // block until closec fires (presumably when the server closes conn)
+		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+		if err != nil {
+			t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, err) // report err; res is not meaningful here
+			continue
+		}
+		if res.StatusCode != tt.want {
+			t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
+		}
+	}
+}
+
 func BenchmarkClientServer(b *testing.B) {
 	b.ReportAllocs()
 	b.StopTimer()
diff --git a/src/net/http/server.go b/src/net/http/server.go
index cd5f9cf..a00085c 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -686,7 +686,7 @@
 		peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
 		c.bufr.Discard(numLeadingCRorLF(peek))
 	}
-	req, err := ReadRequest(c.bufr)
+	req, err := readRequest(c.bufr, false)
 	c.mu.Unlock()
 	if err != nil {
 		if c.r.hitReadLimit() {
@@ -697,6 +697,18 @@
 	c.lastMethod = req.Method
 	c.r.setInfiniteReadLimit()
 
+	hosts, haveHost := req.Header["Host"]
+	if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) {
+		return nil, badRequestError("missing required Host header")
+	}
+	if len(hosts) > 1 {
+		return nil, badRequestError("too many Host headers")
+	}
+	if len(hosts) == 1 && !validHostHeader(hosts[0]) {
+		return nil, badRequestError("malformed Host header")
+	}
+	delete(req.Header, "Host")
+
 	req.RemoteAddr = c.remoteAddr
 	req.TLS = c.tlsState
 	if body, ok := req.Body.(*body); ok {
@@ -1334,6 +1346,13 @@
 	}
 }
 
+// badRequestError is a literal string (used by the server in HTML,
+// unescaped) to tell the user why their request was bad. It should
+// be plain text without user info or other embedded errors.
+type badRequestError string
+
+func (e badRequestError) Error() string { return "Bad Request: " + string(e) }
+
 // Serve a new connection.
 func (c *conn) serve() {
 	c.remoteAddr = c.rwc.RemoteAddr().String()
@@ -1399,7 +1418,11 @@
 			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
 				return // don't reply
 			}
-			io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request")
+			var publicErr string
+			if v, ok := err.(badRequestError); ok {
+				publicErr = ": " + string(v)
+			}
+			io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr)
 			return
 		}