http2: don't send Connection-level headers in Transport

Accept common things that users might try to do to be helpful (managed
by net/http anyway, and previously legal or at best ignored), like

  Connection: close
  Connection: keep-alive
  Transfer-Encoding: chunked

But reject all other connection-level headers, per http2 spec. The
Google GFE enforces this, so we need to filter these before sending,
and give users a better error message for the ones we can't safely
filter. That is, reject any connection-level header that we don't know
the meaning of.

This CL also makes "Connection: close" mean the same as Request.Close,
and respects that as well, which was previously ignored in http2.

Mostly tests.

Updates golang/go#14227

Change-Id: I06e20286f71e8416149588e2c6274a3fce68033b
Reviewed-on: https://go-review.googlesource.com/19223
Reviewed-by: Andrew Gerrand <adg@golang.org>
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 8e2e5f2..0aaa067 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -562,7 +562,27 @@
 	return 0
 }
 
+// checkConnHeaders checks whether req has any invalid connection-level headers.
+// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields.
+// Certain headers are special-cased as okay but not transmitted later.
+func checkConnHeaders(req *http.Request) error {
+	if v := req.Header.Get("Upgrade"); v != "" {
+		return errors.New("http2: invalid Upgrade request header")
+	}
+	if v := req.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(req.Header["Transfer-Encoding"]) > 1 {
+		return errors.New("http2: invalid Transfer-Encoding request header")
+	}
+	if v := req.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(req.Header["Connection"]) > 1 {
+		return errors.New("http2: invalid Connection request header")
+	}
+	return nil
+}
+
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
+	if err := checkConnHeaders(req); err != nil {
+		return nil, err
+	}
+
 	trailers, err := commaSeparatedTrailers(req)
 	if err != nil {
 		return nil, err
@@ -914,13 +934,24 @@
 	var didUA bool
 	for k, vv := range req.Header {
 		lowKey := strings.ToLower(k)
-		if lowKey == "host" || lowKey == "content-length" {
+		switch lowKey {
+		case "host", "content-length":
+			// Host is :authority, already sent.
+			// Content-Length is automatic, set below.
 			continue
-		}
-		if lowKey == "user-agent" {
+		case "connection", "proxy-connection", "transfer-encoding", "upgrade":
+			// Per 8.1.2.2 Connection-Specific Header
+			// Fields, don't send connection-specific
+			// fields. We deal with these earlier in
+			// RoundTrip, deciding whether they're
+			// error-worthy, but we don't want to mutate
+			// the user's *Request so at this point, just
+			// skip over them at this point.
+			continue
+		case "user-agent":
 			// Match Go's http1 behavior: at most one
-			// User-Agent.  If set to nil or empty string,
-			// then omit it.  Otherwise if not mentioned,
+			// User-Agent. If set to nil or empty string,
+			// then omit it. Otherwise if not mentioned,
 			// include the default (below).
 			didUA = true
 			if len(vv) < 1 {
@@ -1030,8 +1061,9 @@
 
 // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
 type clientConnReadLoop struct {
-	cc        *ClientConn
-	activeRes map[uint32]*clientStream // keyed by streamID
+	cc            *ClientConn
+	activeRes     map[uint32]*clientStream // keyed by streamID
+	closeWhenIdle bool
 
 	hdec *hpack.Decoder
 
@@ -1091,7 +1123,7 @@
 
 func (rl *clientConnReadLoop) run() error {
 	cc := rl.cc
-	closeWhenIdle := cc.t.disableKeepAlives()
+	rl.closeWhenIdle = cc.t.disableKeepAlives()
 	gotReply := false // ever saw a reply
 	for {
 		f, err := cc.fr.ReadFrame()
@@ -1140,7 +1172,7 @@
 		if err != nil {
 			return err
 		}
-		if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
+		if rl.closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
 			cc.closeIfIdle()
 		}
 	}
@@ -1407,6 +1439,9 @@
 	}
 	cs.bufPipe.closeWithErrorAndCode(err, code)
 	delete(rl.activeRes, cs.ID)
+	if cs.req.Close || cs.req.Header.Get("Connection") == "close" {
+		rl.closeWhenIdle = true
+	}
 }
 
 func (cs *clientStream) copyTrailers() {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 868fd1f..827bfdb 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -20,6 +20,7 @@
 	"net/url"
 	"os"
 	"reflect"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -100,8 +101,7 @@
 		t.Errorf("Body = %q; want %q", slurp, body)
 	}
 }
-
-func TestTransportReusesConns(t *testing.T) {
+func onSameConn(t *testing.T, modReq func(*http.Request)) bool {
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 		io.WriteString(w, r.RemoteAddr)
 	}, optOnlyServer)
@@ -113,6 +113,7 @@
 		if err != nil {
 			t.Fatal(err)
 		}
+		modReq(req)
 		res, err := tr.RoundTrip(req)
 		if err != nil {
 			t.Fatal(err)
@@ -130,8 +131,24 @@
 	}
 	first := get()
 	second := get()
-	if first != second {
-		t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
+	return first == second
+}
+
+func TestTransportReusesConns(t *testing.T) {
+	if !onSameConn(t, func(*http.Request) {}) {
+		t.Errorf("first and second responses were on different connections")
+	}
+}
+
+func TestTransportReusesConn_RequestClose(t *testing.T) {
+	if onSameConn(t, func(r *http.Request) { r.Close = true }) {
+		t.Errorf("first and second responses were not on different connections")
+	}
+}
+
+func TestTransportReusesConn_ConnClose(t *testing.T) {
+	if onSameConn(t, func(r *http.Request) { r.Header.Set("Connection", "close") }) {
+		t.Errorf("first and second responses were not on different connections")
 	}
 }
 
@@ -1549,3 +1566,97 @@
 	}
 	defer res.Body.Close()
 }
+
+// RFC 7540 section 8.1.2.2
+func TestTransportRejectsConnHeaders(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		var got []string
+		for k := range r.Header {
+			got = append(got, k)
+		}
+		sort.Strings(got)
+		w.Header().Set("Got-Header", strings.Join(got, ","))
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+
+	tests := []struct {
+		key   string
+		value []string
+		want  string
+	}{
+		{
+			key:   "Upgrade",
+			value: []string{"anything"},
+			want:  "ERROR: http2: invalid Upgrade request header",
+		},
+		{
+			key:   "Connection",
+			value: []string{"foo"},
+			want:  "ERROR: http2: invalid Connection request header",
+		},
+		{
+			key:   "Connection",
+			value: []string{"close"},
+			want:  "Accept-Encoding,User-Agent",
+		},
+		{
+			key:   "Connection",
+			value: []string{"close", "something-else"},
+			want:  "ERROR: http2: invalid Connection request header",
+		},
+		{
+			key:   "Connection",
+			value: []string{"keep-alive"},
+			want:  "Accept-Encoding,User-Agent",
+		},
+		{
+			key:   "Proxy-Connection", // just deleted and ignored
+			value: []string{"keep-alive"},
+			want:  "Accept-Encoding,User-Agent",
+		},
+		{
+			key:   "Transfer-Encoding",
+			value: []string{""},
+			want:  "Accept-Encoding,User-Agent",
+		},
+		{
+			key:   "Transfer-Encoding",
+			value: []string{"foo"},
+			want:  "ERROR: http2: invalid Transfer-Encoding request header",
+		},
+		{
+			key:   "Transfer-Encoding",
+			value: []string{"chunked"},
+			want:  "Accept-Encoding,User-Agent",
+		},
+		{
+			key:   "Transfer-Encoding",
+			value: []string{"chunked", "other"},
+			want:  "ERROR: http2: invalid Transfer-Encoding request header",
+		},
+		{
+			key:   "Content-Length",
+			value: []string{"123"},
+			want:  "Accept-Encoding,User-Agent",
+		},
+	}
+
+	for _, tt := range tests {
+		req, _ := http.NewRequest("GET", st.ts.URL, nil)
+		req.Header[tt.key] = tt.value
+		res, err := tr.RoundTrip(req)
+		var got string
+		if err != nil {
+			got = fmt.Sprintf("ERROR: %v", err)
+		} else {
+			got = res.Header.Get("Got-Header")
+			res.Body.Close()
+		}
+		if got != tt.want {
+			t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
+		}
+	}
+}