http2/h2c: support direct hand off of h2c-upgrade connections

The initial request on an h2c-upgraded connection is written as an
HTTP/1 request, with the response sent as an HTTP/2 stream.

The h2c package handled this request by constructing a sequence of
bytes representing an HTTP/2 stream containing the initial request,
prepending those bytes to the remainder of the connection, and
presenting that to the HTTP/2 server as if no upgrade had happened.
This translation did not handle request bodies. Handling request
bodies under this model would be difficult, since it would require
also translating the HTTP/2 flow control.

Rewrite the h2c upgrade to explicitly hand off the request to the
HTTP/2 server instead.

Fixes golang/go#52882.

Change-Id: I26e0f12e2b1c8b48fd36ba47baea076424983553
Reviewed-on: https://go-review.googlesource.com/c/net/+/407454
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
diff --git a/http2/h2c/h2c.go b/http2/h2c/h2c.go
index c0970d8..c3df711 100644
--- a/http2/h2c/h2c.go
+++ b/http2/h2c/h2c.go
@@ -12,7 +12,6 @@
 	"bufio"
 	"bytes"
 	"encoding/base64"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -25,7 +24,6 @@
 
 	"golang.org/x/net/http/httpguts"
 	"golang.org/x/net/http2"
-	"golang.org/x/net/http2/hpack"
 )
 
 var (
@@ -61,6 +59,10 @@
 // Once a request is recognized as h2c, we hijack the connection and convert it
 // to an HTTP/2 connection which is understandable to s.ServeConn. (s.ServeConn
 // understands HTTP/2 except for the h2c part of it.)
+//
+// The first request on an h2c connection is read entirely into memory before
+// the Handler is called. To limit the memory consumed by this request, wrap
+// the result of NewHandler in an http.MaxBytesHandler.
 func NewHandler(h http.Handler, s *http2.Server) http.Handler {
 	return &h2cHandler{
 		Handler: h,
@@ -83,24 +85,31 @@
 			return
 		}
 		defer conn.Close()
-
 		s.s.ServeConn(conn, &http2.ServeConnOpts{
-			Context: r.Context(),
-			Handler: s.Handler,
+			Context:          r.Context(),
+			Handler:          s.Handler,
+			SawClientPreface: true,
 		})
 		return
 	}
 	// Handle Upgrade to h2c (RFC 7540 Section 3.2)
-	if conn, err := h2cUpgrade(w, r); err == nil {
+	if isH2CUpgrade(r.Header) {
+		conn, settings, err := h2cUpgrade(w, r)
+		if err != nil {
+			if http2VerboseLogs {
+				log.Printf("h2c: error h2c upgrade: %v", err)
+			}
+			return
+		}
 		defer conn.Close()
-
 		s.s.ServeConn(conn, &http2.ServeConnOpts{
-			Context: r.Context(),
-			Handler: s.Handler,
+			Context:        r.Context(),
+			Handler:        s.Handler,
+			UpgradeRequest: r,
+			Settings:       settings,
 		})
 		return
 	}
-
 	s.Handler.ServeHTTP(w, r)
 	return
 }
@@ -113,11 +122,11 @@
 func initH2CWithPriorKnowledge(w http.ResponseWriter) (net.Conn, error) {
 	hijacker, ok := w.(http.Hijacker)
 	if !ok {
-		panic("Hijack not supported.")
+		return nil, errors.New("h2c: connection does not support Hijack")
 	}
 	conn, rw, err := hijacker.Hijack()
 	if err != nil {
-		panic(fmt.Sprintf("Hijack failed: %v", err))
+		return nil, err
 	}
 
 	const expectedBody = "SM\r\n\r\n"
@@ -125,249 +134,40 @@
 	buf := make([]byte, len(expectedBody))
 	n, err := io.ReadFull(rw, buf)
 	if err != nil {
-		return nil, fmt.Errorf("could not read from the buffer: %s", err)
+		return nil, fmt.Errorf("h2c: error reading client preface: %s", err)
 	}
 
 	if string(buf[:n]) == expectedBody {
-		c := &rwConn{
-			Conn:      conn,
-			Reader:    io.MultiReader(strings.NewReader(http2.ClientPreface), rw),
-			BufWriter: rw.Writer,
-		}
-		return c, nil
+		return newBufConn(conn, rw), nil
 	}
 
 	conn.Close()
-	if http2VerboseLogs {
-		log.Printf(
-			"h2c: missing the request body portion of the client preface. Wanted: %v Got: %v",
-			[]byte(expectedBody),
-			buf[0:n],
-		)
-	}
-	return nil, errors.New("invalid client preface")
-}
-
-// drainClientPreface reads a single instance of the HTTP/2 client preface from
-// the supplied reader.
-func drainClientPreface(r io.Reader) error {
-	var buf bytes.Buffer
-	prefaceLen := int64(len(http2.ClientPreface))
-	n, err := io.CopyN(&buf, r, prefaceLen)
-	if err != nil {
-		return err
-	}
-	if n != prefaceLen || buf.String() != http2.ClientPreface {
-		return fmt.Errorf("Client never sent: %s", http2.ClientPreface)
-	}
-	return nil
+	return nil, errors.New("h2c: invalid client preface")
 }
 
 // h2cUpgrade establishes a h2c connection using the HTTP/1 upgrade (Section 3.2).
-func h2cUpgrade(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
-	if !isH2CUpgrade(r.Header) {
-		return nil, errors.New("non-conforming h2c headers")
-	}
-
-	// Initial bytes we put into conn to fool http2 server
-	initBytes, _, err := convertH1ReqToH2(r)
+func h2cUpgrade(w http.ResponseWriter, r *http.Request) (_ net.Conn, settings []byte, err error) {
+	settings, err = getH2Settings(r.Header)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
-
 	hijacker, ok := w.(http.Hijacker)
 	if !ok {
-		return nil, errors.New("hijack not supported.")
+		return nil, nil, errors.New("h2c: connection does not support Hijack")
 	}
+
+	body, _ := io.ReadAll(r.Body)
+	r.Body = io.NopCloser(bytes.NewBuffer(body))
+
 	conn, rw, err := hijacker.Hijack()
 	if err != nil {
-		return nil, fmt.Errorf("hijack failed: %v", err)
+		return nil, nil, err
 	}
 
 	rw.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n" +
 		"Connection: Upgrade\r\n" +
 		"Upgrade: h2c\r\n\r\n"))
-	rw.Flush()
-
-	// A conforming client will now send an H2 client preface which need to drain
-	// since we already sent this.
-	if err := drainClientPreface(rw); err != nil {
-		return nil, err
-	}
-
-	c := &rwConn{
-		Conn:      conn,
-		Reader:    io.MultiReader(initBytes, rw),
-		BufWriter: newSettingsAckSwallowWriter(rw.Writer),
-	}
-	return c, nil
-}
-
-// convert the data contained in the HTTP/1 upgrade request into the HTTP/2
-// version in byte form.
-func convertH1ReqToH2(r *http.Request) (*bytes.Buffer, []http2.Setting, error) {
-	h2Bytes := bytes.NewBuffer([]byte((http2.ClientPreface)))
-	framer := http2.NewFramer(h2Bytes, nil)
-	settings, err := getH2Settings(r.Header)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	if err := framer.WriteSettings(settings...); err != nil {
-		return nil, nil, err
-	}
-
-	headerBytes, err := getH2HeaderBytes(r, getMaxHeaderTableSize(settings))
-	if err != nil {
-		return nil, nil, err
-	}
-
-	maxFrameSize := int(getMaxFrameSize(settings))
-	needOneHeader := len(headerBytes) < maxFrameSize
-	err = framer.WriteHeaders(http2.HeadersFrameParam{
-		StreamID:      1,
-		BlockFragment: headerBytes,
-		EndHeaders:    needOneHeader,
-	})
-	if err != nil {
-		return nil, nil, err
-	}
-
-	for i := maxFrameSize; i < len(headerBytes); i += maxFrameSize {
-		if len(headerBytes)-i > maxFrameSize {
-			if err := framer.WriteContinuation(1,
-				false, // endHeaders
-				headerBytes[i:maxFrameSize]); err != nil {
-				return nil, nil, err
-			}
-		} else {
-			if err := framer.WriteContinuation(1,
-				true, // endHeaders
-				headerBytes[i:]); err != nil {
-				return nil, nil, err
-			}
-		}
-	}
-
-	return h2Bytes, settings, nil
-}
-
-// getMaxFrameSize returns the SETTINGS_MAX_FRAME_SIZE. If not present default
-// value is 16384 as specified by RFC 7540 Section 6.5.2.
-func getMaxFrameSize(settings []http2.Setting) uint32 {
-	for _, setting := range settings {
-		if setting.ID == http2.SettingMaxFrameSize {
-			return setting.Val
-		}
-	}
-	return 16384
-}
-
-// getMaxHeaderTableSize returns the SETTINGS_HEADER_TABLE_SIZE. If not present
-// default value is 4096 as specified by RFC 7540 Section 6.5.2.
-func getMaxHeaderTableSize(settings []http2.Setting) uint32 {
-	for _, setting := range settings {
-		if setting.ID == http2.SettingHeaderTableSize {
-			return setting.Val
-		}
-	}
-	return 4096
-}
-
-// bufWriter is a Writer interface that also has a Flush method.
-type bufWriter interface {
-	io.Writer
-	Flush() error
-}
-
-// rwConn implements net.Conn but overrides Read and Write so that reads and
-// writes are forwarded to the provided io.Reader and bufWriter.
-type rwConn struct {
-	net.Conn
-	io.Reader
-	BufWriter bufWriter
-}
-
-// Read forwards reads to the underlying Reader.
-func (c *rwConn) Read(p []byte) (int, error) {
-	return c.Reader.Read(p)
-}
-
-// Write forwards writes to the underlying bufWriter and immediately flushes.
-func (c *rwConn) Write(p []byte) (int, error) {
-	n, err := c.BufWriter.Write(p)
-	if err := c.BufWriter.Flush(); err != nil {
-		return 0, err
-	}
-	return n, err
-}
-
-// settingsAckSwallowWriter is a writer that normally forwards bytes to its
-// underlying Writer, but swallows the first SettingsAck frame that it sees.
-type settingsAckSwallowWriter struct {
-	Writer     *bufio.Writer
-	buf        []byte
-	didSwallow bool
-}
-
-// newSettingsAckSwallowWriter returns a new settingsAckSwallowWriter.
-func newSettingsAckSwallowWriter(w *bufio.Writer) *settingsAckSwallowWriter {
-	return &settingsAckSwallowWriter{
-		Writer:     w,
-		buf:        make([]byte, 0),
-		didSwallow: false,
-	}
-}
-
-// Write implements io.Writer interface. Normally forwards bytes to w.Writer,
-// except for the first Settings ACK frame that it sees.
-func (w *settingsAckSwallowWriter) Write(p []byte) (int, error) {
-	if !w.didSwallow {
-		w.buf = append(w.buf, p...)
-		// Process all the frames we have collected into w.buf
-		for {
-			// Append until we get full frame header which is 9 bytes
-			if len(w.buf) < 9 {
-				break
-			}
-			// Check if we have collected a whole frame.
-			fh, err := http2.ReadFrameHeader(bytes.NewBuffer(w.buf))
-			if err != nil {
-				// Corrupted frame, fail current Write
-				return 0, err
-			}
-			fSize := fh.Length + 9
-			if uint32(len(w.buf)) < fSize {
-				// Have not collected whole frame. Stop processing buf, and withhold on
-				// forward bytes to w.Writer until we get the full frame.
-				break
-			}
-
-			// We have now collected a whole frame.
-			if fh.Type == http2.FrameSettings && fh.Flags.Has(http2.FlagSettingsAck) {
-				// If Settings ACK frame, do not forward to underlying writer, remove
-				// bytes from w.buf, and record that we have swallowed Settings Ack
-				// frame.
-				w.didSwallow = true
-				w.buf = w.buf[fSize:]
-				continue
-			}
-
-			// Not settings ack frame. Forward bytes to w.Writer.
-			if _, err := w.Writer.Write(w.buf[:fSize]); err != nil {
-				// Couldn't forward bytes. Fail current Write.
-				return 0, err
-			}
-			w.buf = w.buf[fSize:]
-		}
-		return len(p), nil
-	}
-	return w.Writer.Write(p)
-}
-
-// Flush calls w.Writer.Flush.
-func (w *settingsAckSwallowWriter) Flush() error {
-	return w.Writer.Flush()
+	return newBufConn(conn, rw), settings, nil
 }
 
 // isH2CUpgrade returns true if the header properly request an upgrade to h2c
@@ -377,9 +177,8 @@
 		httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Connection")], "HTTP2-Settings")
 }
 
-// getH2Settings returns the []http2.Setting that are encoded in the
-// HTTP2-Settings header.
-func getH2Settings(h http.Header) ([]http2.Setting, error) {
+// getH2Settings returns the settings in the HTTP2-Settings header.
+func getH2Settings(h http.Header) ([]byte, error) {
 	vals, ok := h[textproto.CanonicalMIMEHeaderKey("HTTP2-Settings")]
 	if !ok {
 		return nil, errors.New("missing HTTP2-Settings header")
@@ -387,115 +186,40 @@
 	if len(vals) != 1 {
 		return nil, fmt.Errorf("expected 1 HTTP2-Settings. Got: %v", vals)
 	}
-	settings, err := decodeSettings(vals[0])
+	settings, err := base64.RawURLEncoding.DecodeString(vals[0])
 	if err != nil {
-		return nil, fmt.Errorf("Invalid HTTP2-Settings: %q", vals[0])
+		return nil, err
 	}
 	return settings, nil
 }
 
-// decodeSettings decodes the base64url header value of the HTTP2-Settings
-// header. RFC 7540 Section 3.2.1.
-func decodeSettings(headerVal string) ([]http2.Setting, error) {
-	b, err := base64.RawURLEncoding.DecodeString(headerVal)
-	if err != nil {
-		return nil, err
+func newBufConn(conn net.Conn, rw *bufio.ReadWriter) net.Conn {
+	rw.Flush()
+	if rw.Reader.Buffered() == 0 {
+		// If there's no buffered data to be read,
+		// we can just discard the bufio.ReadWriter.
+		return conn
 	}
-	if len(b)%6 != 0 {
-		return nil, err
-	}
-	settings := make([]http2.Setting, 0)
-	for i := 0; i < len(b)/6; i++ {
-		settings = append(settings, http2.Setting{
-			ID:  http2.SettingID(binary.BigEndian.Uint16(b[i*6 : i*6+2])),
-			Val: binary.BigEndian.Uint32(b[i*6+2 : i*6+6]),
-		})
-	}
-
-	return settings, nil
+	return &bufConn{conn, rw.Reader}
 }
 
-// getH2HeaderBytes return the headers in r a []bytes encoded by HPACK.
-func getH2HeaderBytes(r *http.Request, maxHeaderTableSize uint32) ([]byte, error) {
-	headerBytes := bytes.NewBuffer(nil)
-	hpackEnc := hpack.NewEncoder(headerBytes)
-	hpackEnc.SetMaxDynamicTableSize(maxHeaderTableSize)
-
-	// Section 8.1.2.3
-	err := hpackEnc.WriteField(hpack.HeaderField{
-		Name:  ":method",
-		Value: r.Method,
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	err = hpackEnc.WriteField(hpack.HeaderField{
-		Name:  ":scheme",
-		Value: "http",
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	err = hpackEnc.WriteField(hpack.HeaderField{
-		Name:  ":authority",
-		Value: r.Host,
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	path := r.URL.Path
-	if r.URL.RawQuery != "" {
-		path = strings.Join([]string{path, r.URL.RawQuery}, "?")
-	}
-	err = hpackEnc.WriteField(hpack.HeaderField{
-		Name:  ":path",
-		Value: path,
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	// TODO Implement Section 8.3
-
-	for header, values := range r.Header {
-		// Skip non h2 headers
-		if isNonH2Header(header) {
-			continue
-		}
-		for _, v := range values {
-			err := hpackEnc.WriteField(hpack.HeaderField{
-				Name:  strings.ToLower(header),
-				Value: v,
-			})
-			if err != nil {
-				return nil, err
-			}
-		}
-	}
-	return headerBytes.Bytes(), nil
+// bufConn wraps a net.Conn, but reads drain the bufio.Reader first.
+type bufConn struct {
+	net.Conn
+	*bufio.Reader
 }
 
-// Connection specific headers listed in RFC 7540 Section 8.1.2.2 that are not
-// suppose to be transferred to HTTP/2. The Http2-Settings header is skipped
-// since already use to create the HTTP/2 SETTINGS frame.
-var nonH2Headers = []string{
-	"Connection",
-	"Keep-Alive",
-	"Proxy-Connection",
-	"Transfer-Encoding",
-	"Upgrade",
-	"Http2-Settings",
-}
-
-// isNonH2Header returns true if header should not be transferred to HTTP/2.
-func isNonH2Header(header string) bool {
-	for _, nonH2h := range nonH2Headers {
-		if header == nonH2h {
-			return true
-		}
+func (c *bufConn) Read(p []byte) (int, error) {
+	if c.Reader == nil {
+		return c.Conn.Read(p)
 	}
-	return false
+	n := c.Reader.Buffered()
+	if n == 0 {
+		c.Reader = nil
+		return c.Conn.Read(p)
+	}
+	if n < len(p) {
+		p = p[:n]
+	}
+	return c.Reader.Read(p)
 }
diff --git a/http2/h2c/h2c_test.go b/http2/h2c/h2c_test.go
index d315632..3e5a2eb 100644
--- a/http2/h2c/h2c_test.go
+++ b/http2/h2c/h2c_test.go
@@ -5,8 +5,6 @@
 package h2c
 
 import (
-	"bufio"
-	"bytes"
 	"context"
 	"crypto/tls"
 	"fmt"
@@ -20,34 +18,6 @@
 	"golang.org/x/net/http2"
 )
 
-func TestSettingsAckSwallowWriter(t *testing.T) {
-	var buf bytes.Buffer
-	swallower := newSettingsAckSwallowWriter(bufio.NewWriter(&buf))
-	fw := http2.NewFramer(swallower, nil)
-	fw.WriteSettings(http2.Setting{ID: http2.SettingMaxFrameSize, Val: 2})
-	fw.WriteSettingsAck()
-	fw.WriteData(1, true, []byte{})
-	swallower.Flush()
-
-	fr := http2.NewFramer(nil, bufio.NewReader(&buf))
-
-	f, err := fr.ReadFrame()
-	if err != nil {
-		t.Fatal(err)
-	}
-	if f.Header().Type != http2.FrameSettings {
-		t.Fatalf("Expected first frame to be SETTINGS. Got: %v", f.Header().Type)
-	}
-
-	f, err = fr.ReadFrame()
-	if err != nil {
-		t.Fatal(err)
-	}
-	if f.Header().Type != http2.FrameData {
-		t.Fatalf("Expected first frame to be DATA. Got: %v", f.Header().Type)
-	}
-}
-
 func ExampleNewHandler() {
 	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		fmt.Fprint(w, "Hello world")
diff --git a/http2/server.go b/http2/server.go
index 2d859af..47524a6 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -315,6 +315,20 @@
 	// requests. If nil, BaseConfig.Handler is used. If BaseConfig
 	// or BaseConfig.Handler is nil, http.DefaultServeMux is used.
 	Handler http.Handler
+
+	// UpgradeRequest is an initial request received on a connection
+	// undergoing an h2c upgrade. The request body must have been
+	// completely read from the connection before calling ServeConn,
+	// and the 101 Switching Protocols response written.
+	UpgradeRequest *http.Request
+
+	// Settings is the decoded contents of the HTTP2-Settings header
+	// in an h2c upgrade request.
+	Settings []byte
+
+	// SawClientPreface is set if the HTTP/2 connection preface
+	// has already been read from the connection.
+	SawClientPreface bool
 }
 
 func (o *ServeConnOpts) context() context.Context {
@@ -383,6 +397,7 @@
 		headerTableSize:             initialHeaderTableSize,
 		serveG:                      newGoroutineLock(),
 		pushEnabled:                 true,
+		sawClientPreface:            opts.SawClientPreface,
 	}
 
 	s.state.registerConn(sc)
@@ -465,9 +480,27 @@
 		}
 	}
 
+	if opts.Settings != nil {
+		fr := &SettingsFrame{
+			FrameHeader: FrameHeader{valid: true},
+			p:           opts.Settings,
+		}
+		if err := fr.ForeachSetting(sc.processSetting); err != nil {
+			sc.rejectConn(ErrCodeProtocol, "invalid settings")
+			return
+		}
+		opts.Settings = nil
+	}
+
 	if hook := testHookGetServerConn; hook != nil {
 		hook(sc)
 	}
+
+	if opts.UpgradeRequest != nil {
+		sc.upgradeRequest(opts.UpgradeRequest)
+		opts.UpgradeRequest = nil
+	}
+
 	sc.serve()
 }
 
@@ -512,6 +545,7 @@
 	// Everything following is owned by the serve loop; use serveG.check():
 	serveG                      goroutineLock // used to verify funcs are on serve()
 	pushEnabled                 bool
+	sawClientPreface            bool // preface has already been read, used in h2c upgrade
 	sawFirstSettings            bool // got the initial SETTINGS frame after the preface
 	needToSendSettingsAck       bool
 	unackedSettings             int    // how many SETTINGS have we sent without ACKs?
@@ -974,6 +1008,9 @@
 // returns errPrefaceTimeout on timeout, or an error if the greeting
 // is invalid.
 func (sc *serverConn) readPreface() error {
+	if sc.sawClientPreface {
+		return nil
+	}
 	errc := make(chan error, 1)
 	go func() {
 		// Read the client preface
@@ -1915,6 +1952,26 @@
 	return nil
 }
 
+func (sc *serverConn) upgradeRequest(req *http.Request) {
+	sc.serveG.check()
+	id := uint32(1)
+	sc.maxClientStreamID = id
+	st := sc.newStream(id, 0, stateHalfClosedRemote)
+	st.reqTrailer = req.Trailer
+	if st.reqTrailer != nil {
+		st.trailer = make(http.Header)
+	}
+	rw := sc.newResponseWriter(st, req)
+
+	// Disable any read deadline set by the net/http package
+	// prior to the upgrade.
+	if sc.hs.ReadTimeout != 0 {
+		sc.conn.SetReadDeadline(time.Time{})
+	}
+
+	go sc.runHandler(rw, req, sc.handler.ServeHTTP)
+}
+
 func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
 	sc := st.sc
 	sc.serveG.check()
@@ -2145,6 +2202,11 @@
 	}
 	req = req.WithContext(st.ctx)
 
+	rw := sc.newResponseWriter(st, req)
+	return rw, req, nil
+}
+
+func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *responseWriter {
 	rws := responseWriterStatePool.Get().(*responseWriterState)
 	bwSave := rws.bw
 	*rws = responseWriterState{} // zero all the fields
@@ -2153,10 +2215,7 @@
 	rws.bw.Reset(chunkWriter{rws})
 	rws.stream = st
 	rws.req = req
-	rws.body = body
-
-	rw := &responseWriter{rws: rws}
-	return rw, req, nil
+	return &responseWriter{rws: rws}
 }
 
 // Run on its own goroutine.
@@ -2371,7 +2430,6 @@
 	// immutable within a request:
 	stream *stream
 	req    *http.Request
-	body   *requestBody // to close at end of request, if DATA frames didn't
 	conn   *serverConn
 
 	// TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc