http2: support CONNECT requests
Support CONNECT requests in both the server & transport.
See https://httpwg.github.io/specs/rfc7540.html#CONNECT
When I bundle this into the main Go repo I will also add h1-vs-h2
compatibility tests there, making sure they match behavior. (I now
expect that they do match)
Updates golang/go#13717
Change-Id: I0c65ad47b029419027efb616fed3d8e0e2a363f4
Reviewed-on: https://go-review.googlesource.com/18266
Reviewed-by: Andrew Gerrand <adg@golang.org>
diff --git a/http2/server.go b/http2/server.go
index 9326eb6..eaed248 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -1545,7 +1545,17 @@
func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) {
sc.serveG.check()
rp := &sc.req
- if rp.invalidHeader || rp.method == "" || rp.path == "" ||
+
+ if rp.invalidHeader {
+ return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+ }
+
+ isConnect := rp.method == "CONNECT"
+ if isConnect {
+ if rp.path != "" || rp.scheme != "" || rp.authority == "" {
+ return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+ }
+ } else if rp.method == "" || rp.path == "" ||
(rp.scheme != "https" && rp.scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
@@ -1559,12 +1569,14 @@
// pseudo-header fields"
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
+
bodyOpen := rp.stream.state == stateOpen
if rp.method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies
return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
}
var tlsState *tls.ConnectionState // nil if not scheme https
+
if rp.scheme == "https" {
tlsState = sc.tlsState
}
@@ -1605,18 +1617,26 @@
stream: rp.stream,
needsContinue: needsContinue,
}
- // TODO: handle asterisk '*' requests + test
- url, err := url.ParseRequestURI(rp.path)
- if err != nil {
- // TODO: find the right error code?
- return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+ var url_ *url.URL
+ var requestURI string
+ if isConnect {
+ url_ = &url.URL{Host: rp.authority}
+ requestURI = rp.authority // mimic HTTP/1 server behavior
+ } else {
+ var err error
+ // TODO: handle asterisk '*' requests + test
+ url_, err = url.ParseRequestURI(rp.path)
+ if err != nil {
+ return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol}
+ }
+ requestURI = rp.path
}
req := &http.Request{
Method: rp.method,
- URL: url,
+ URL: url_,
RemoteAddr: sc.remoteAddrStr,
Header: rp.header,
- RequestURI: rp.path,
+ RequestURI: requestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
diff --git a/http2/server_test.go b/http2/server_test.go
index 4f22cde..88f3cfb 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -901,6 +901,60 @@
st.wantRSTStream(1, ErrCodeProtocol)
}
+func TestServer_Request_Connect(t *testing.T) {
+ testServerRequest(t, func(st *serverTester) {
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeaderRaw(
+ ":method", "CONNECT",
+ ":authority", "example.com:123",
+ ),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ }, func(r *http.Request) {
+ if g, w := r.Method, "CONNECT"; g != w {
+ t.Errorf("Method = %q; want %q", g, w)
+ }
+ if g, w := r.RequestURI, "example.com:123"; g != w {
+ t.Errorf("RequestURI = %q; want %q", g, w)
+ }
+ if g, w := r.URL.Host, "example.com:123"; g != w {
+ t.Errorf("URL.Host = %q; want %q", g, w)
+ }
+ })
+}
+
+func TestServer_Request_Connect_InvalidPath(t *testing.T) {
+ testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeaderRaw(
+ ":method", "CONNECT",
+ ":authority", "example.com:123",
+ ":path", "/bogus",
+ ),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ })
+}
+
+func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
+ testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeaderRaw(
+ ":method", "CONNECT",
+ ":authority", "example.com:123",
+ ":scheme", "https",
+ ),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ })
+}
+
func TestServer_Ping(t *testing.T) {
st := newServerTester(t, nil)
defer st.Close()
@@ -1222,7 +1276,7 @@
// test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
- testServerRejects(t, func(st *serverTester) {
+ testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
@@ -1240,7 +1294,7 @@
// test HEADERS w/o EndHeaders + PING (should get rejected)
func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
- testServerRejects(t, func(st *serverTester) {
+ testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
@@ -1255,7 +1309,7 @@
// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
- testServerRejects(t, func(st *serverTester) {
+ testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
@@ -1271,7 +1325,7 @@
// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
- testServerRejects(t, func(st *serverTester) {
+ testServerRejectsConn(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1,
BlockFragment: st.encodeHeader(),
@@ -1286,7 +1340,7 @@
// No HEADERS on stream 0.
func TestServer_Rejects_Headers0(t *testing.T) {
- testServerRejects(t, func(st *serverTester) {
+ testServerRejectsConn(t, func(st *serverTester) {
st.fr.AllowIllegalWrites = true
st.writeHeaders(HeadersFrameParam{
StreamID: 0,
@@ -1299,7 +1353,7 @@
// No CONTINUATION on stream 0.
func TestServer_Rejects_Continuation0(t *testing.T) {
- testServerRejects(t, func(st *serverTester) {
+ testServerRejectsConn(t, func(st *serverTester) {
st.fr.AllowIllegalWrites = true
if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
t.Fatal(err)
@@ -1308,7 +1362,7 @@
}
func TestServer_Rejects_PushPromise(t *testing.T) {
- testServerRejects(t, func(st *serverTester) {
+ testServerRejectsConn(t, func(st *serverTester) {
pp := PushPromiseParam{
StreamID: 1,
PromiseID: 3,
@@ -1319,10 +1373,10 @@
})
}
-// testServerRejects tests that the server hangs up with a GOAWAY
+// testServerRejectsConn tests that the server hangs up with a GOAWAY
// frame and a server close after the client does something
// deserving a CONNECTION_ERROR.
-func testServerRejects(t *testing.T, writeReq func(*serverTester)) {
+func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
st.addLogFilter("connection error: PROTOCOL_ERROR")
defer st.Close()
@@ -1348,6 +1402,16 @@
}
}
+// testServerRejectsStream tests that the server sends a RST_STREAM with the provided
+// error code after a client sends a bogus request.
+func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
+ defer st.Close()
+ st.greet()
+ writeReq(st)
+ st.wantRSTStream(1, code)
+}
+
// testServerRequest sets up an idle HTTP/2 connection and lets you
// write a single request with writeReq, and then verify that the
// *http.Request is built correctly in checkReq.
diff --git a/http2/transport.go b/http2/transport.go
index 781be9d..db2af1e 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -775,8 +775,10 @@
// [RFC3986]).
cc.writeHeader(":authority", host) // probably not right for all sites
cc.writeHeader(":method", req.Method)
- cc.writeHeader(":path", req.URL.RequestURI())
- cc.writeHeader(":scheme", "https")
+ if req.Method != "CONNECT" {
+ cc.writeHeader(":path", req.URL.RequestURI())
+ cc.writeHeader(":scheme", "https")
+ }
if trailers != "" {
cc.writeHeader("trailer", trailers)
}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index ef8eaa9..8e91678 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -760,3 +760,62 @@
t.Fatal(err)
}
}
+
+func TestTransportConnectRequest(t *testing.T) {
+ gotc := make(chan *http.Request, 1)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ gotc <- r
+ }, optOnlyServer)
+ defer st.Close()
+
+ u, err := url.Parse(st.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+ c := &http.Client{Transport: tr}
+
+ tests := []struct {
+ req *http.Request
+ want string
+ }{
+ {
+ req: &http.Request{
+ Method: "CONNECT",
+ Header: http.Header{},
+ URL: u,
+ },
+ want: u.Host,
+ },
+ {
+ req: &http.Request{
+ Method: "CONNECT",
+ Header: http.Header{},
+ URL: u,
+ Host: "example.com:123",
+ },
+ want: "example.com:123",
+ },
+ }
+
+ for i, tt := range tests {
+ res, err := c.Do(tt.req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ res.Body.Close()
+ req := <-gotc
+ if req.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", req.Method)
+ }
+ if req.Host != tt.want {
+ t.Errorf("Host = %q; want %q", req.Host, tt.want)
+ }
+ if req.URL.Host != tt.want {
+ t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
+ }
+ }
+}