http2: implement support for server push

This makes x/net/http2's ResponseWriter implement the new interface,
http.Pusher. This new interface requires Go 1.8. When compiled against
older versions of Go, the ResponseWriter does not have a Push method.

Fixes golang/go#13443

Change-Id: I8486ffe4bb5562a94270ace21e90e8c9a4653da0
Reviewed-on: https://go-review.googlesource.com/29439
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/server.go b/http2/server.go
index 50bc112..8fed0b1 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -33,6 +33,7 @@
 	"fmt"
 	"io"
 	"log"
+	"math"
 	"net"
 	"net/http"
 	"net/textproto"
@@ -262,9 +263,11 @@
 		streams:           make(map[uint32]*stream),
 		readFrameCh:       make(chan readFrameResult),
 		wantWriteFrameCh:  make(chan FrameWriteRequest, 8),
+		wantStartPushCh:   make(chan startPushRequest, 8),
 		wroteFrameCh:      make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync
 		bodyReadCh:        make(chan bodyReadMsg),         // buffering doesn't matter either way
 		doneServing:       make(chan struct{}),
+		clientMaxStreams:  math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
 		advMaxStreams:     s.maxConcurrentStreams(),
 		initialWindowSize: initialWindowSize,
 		maxFrameSize:      initialMaxFrameSize,
@@ -361,6 +364,7 @@
 	doneServing      chan struct{}          // closed when serverConn.serve ends
 	readFrameCh      chan readFrameResult   // written by serverConn.readFrames
 	wantWriteFrameCh chan FrameWriteRequest // from handlers -> serve
+	wantStartPushCh  chan startPushRequest  // from handlers -> serve
 	wroteFrameCh     chan frameWriteResult  // from writeFrameAsync -> serve, tickles more frame writes
 	bodyReadCh       chan bodyReadMsg       // from handlers -> serve
 	testHookCh       chan func(int)         // code to run on the serve loop
@@ -378,8 +382,10 @@
 	unackedSettings       int    // how many SETTINGS have we sent without ACKs?
 	clientMaxStreams      uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit)
 	advMaxStreams         uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
-	curOpenStreams        uint32 // client's number of open streams
-	maxStreamID           uint32 // max ever seen
+	curClientStreams      uint32 // number of open streams initiated by the client
+	curPushedStreams      uint32 // number of open streams initiated by server push
+	maxStreamID           uint32 // max ever seen from client
+	maxPushPromiseID      uint32 // ID of the last push promise, or 0 if there have been no pushes
 	streams               map[uint32]*stream
 	initialWindowSize     int32
 	maxFrameSize          int32
@@ -457,7 +463,7 @@
 
 func (sc *serverConn) state(streamID uint32) (streamState, *stream) {
 	sc.serveG.check()
-	// http://http2.github.io/http2-spec/#rfc.section.5.1
+	// http://tools.ietf.org/html/rfc7540#section-5.1
 	if st, ok := sc.streams[streamID]; ok {
 		return st.state, st
 	}
@@ -701,6 +707,8 @@
 		select {
 		case wr := <-sc.wantWriteFrameCh:
 			sc.writeFrame(wr)
+		case spr := <-sc.wantStartPushCh:
+			sc.startPush(spr)
 		case res := <-sc.wroteFrameCh:
 			sc.wroteFrame(res)
 		case res := <-sc.readFrameCh:
@@ -881,6 +889,16 @@
 			panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wr))
 		}
 	}
+	if wpp, ok := wr.write.(*writePushPromise); ok {
+		var err error
+		wpp.promisedID, err = wpp.allocatePromisedID()
+		if err != nil {
+			if wr.done != nil {
+				wr.done <- err
+			}
+			return
+		}
+	}
 
 	sc.writingFrame = true
 	sc.needsFrameFlush = true
@@ -1204,8 +1222,12 @@
 		panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
 	}
 	st.state = stateClosed
-	sc.curOpenStreams--
-	if sc.curOpenStreams == 0 {
+	if st.isPushed() {
+		sc.curPushedStreams--
+	} else {
+		sc.curClientStreams--
+	}
+	if sc.curClientStreams+sc.curPushedStreams == 0 {
 		sc.setConnState(http.StateIdle)
 	}
 	delete(sc.streams, st.id)
@@ -1388,6 +1410,11 @@
 	return nil
 }
 
+// isPushed reports whether the stream is server-initiated.
+func (st *stream) isPushed() bool {
+	return st.id%2 == 0
+}
+
 // endStream closes a Request.Body's pipe. It is called when a DATA
 // frame says a request body is over (or after trailers).
 func (st *stream) endStream() {
@@ -1417,12 +1444,12 @@
 
 func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
 	sc.serveG.check()
-	id := f.Header().StreamID
+	id := f.StreamID
 	if sc.inGoAway {
 		// Ignore.
 		return nil
 	}
-	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
+	// http://tools.ietf.org/html/rfc7540#section-5.1.1
 	// Streams initiated by a client MUST use odd-numbered stream
 	// identifiers. [...] An endpoint that receives an unexpected
 	// stream identifier MUST respond with a connection error
@@ -1434,8 +1461,7 @@
 	// send a trailer for an open one. If we already have a stream
 	// open, let it process its own HEADERS frame (trailers at this
 	// point, if it's valid).
-	st := sc.streams[f.Header().StreamID]
-	if st != nil {
+	if st := sc.streams[f.StreamID]; st != nil {
 		return st.processTrailerHeaders(f)
 	}
 
@@ -1453,48 +1479,31 @@
 		sc.idleTimer.Stop()
 	}
 
-	ctx, cancelCtx := contextWithCancel(sc.baseCtx)
-	st = &stream{
-		sc:        sc,
-		id:        id,
-		state:     stateOpen,
-		ctx:       ctx,
-		cancelCtx: cancelCtx,
-	}
-	if f.StreamEnded() {
-		st.state = stateHalfClosedRemote
-	}
-	st.cw.Init()
-
-	st.flow.conn = &sc.flow // link to conn-level counter
-	st.flow.add(sc.initialWindowSize)
-	st.inflow.conn = &sc.inflow      // link to conn-level counter
-	st.inflow.add(initialWindowSize) // TODO: update this when we send a higher initial window size in the initial settings
-
-	sc.streams[id] = st
-	sc.writeSched.OpenStream(st.id, OpenStreamOptions{})
-	sc.curOpenStreams++
-	if sc.curOpenStreams == 1 {
-		sc.setConnState(http.StateActive)
-	}
-	if sc.curOpenStreams > sc.advMaxStreams {
-		// "Endpoints MUST NOT exceed the limit set by their
-		// peer. An endpoint that receives a HEADERS frame
-		// that causes their advertised concurrent stream
-		// limit to be exceeded MUST treat this as a stream
-		// error (Section 5.4.2) of type PROTOCOL_ERROR or
-		// REFUSED_STREAM."
+	// http://tools.ietf.org/html/rfc7540#section-5.1.2
+	// [...] Endpoints MUST NOT exceed the limit set by their peer. An
+	// endpoint that receives a HEADERS frame that causes their
+	// advertised concurrent stream limit to be exceeded MUST treat
+	// this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR
+	// or REFUSED_STREAM.
+	if sc.curClientStreams+1 > sc.advMaxStreams {
 		if sc.unackedSettings == 0 {
 			// They should know better.
-			return streamError(st.id, ErrCodeProtocol)
+			return streamError(id, ErrCodeProtocol)
 		}
 		// Assume it's a network race, where they just haven't
 		// received our last SETTINGS update. But actually
 		// this can't happen yet, because we don't yet provide
 		// a way for users to adjust server parameters at
 		// runtime.
-		return streamError(st.id, ErrCodeRefusedStream)
+		return streamError(id, ErrCodeRefusedStream)
 	}
+
+	initialState := stateOpen
+	if f.StreamEnded() {
+		initialState = stateHalfClosedRemote
+	}
+	st := sc.newStream(id, 0, initialState)
+
 	if f.HasPriority() {
 		if err := checkPriority(f.StreamID, f.Priority); err != nil {
 			return err
@@ -1517,7 +1526,7 @@
 	if f.Truncated {
 		// Their header list was too long. Send a 431 error.
 		handler = handleHeaderListTooLong
-	} else if err := checkValidHTTP2Request(req); err != nil {
+	} else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil {
 		handler = new400Handler(err)
 	}
 
@@ -1590,21 +1599,56 @@
 	return nil
 }
 
+func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream {
+	sc.serveG.check()
+	if id == 0 {
+		panic("internal error: cannot create stream with id 0")
+	}
+
+	ctx, cancelCtx := contextWithCancel(sc.baseCtx)
+	st := &stream{
+		sc:        sc,
+		id:        id,
+		state:     state,
+		ctx:       ctx,
+		cancelCtx: cancelCtx,
+	}
+	st.cw.Init()
+	st.flow.conn = &sc.flow // link to conn-level counter
+	st.flow.add(sc.initialWindowSize)
+	st.inflow.conn = &sc.inflow      // link to conn-level counter
+	st.inflow.add(initialWindowSize) // TODO: update this when we send a higher initial window size in the initial settings
+
+	sc.streams[id] = st
+	sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID})
+	if st.isPushed() {
+		sc.curPushedStreams++
+	} else {
+		sc.curClientStreams++
+	}
+	if sc.curClientStreams+sc.curPushedStreams == 1 {
+		sc.setConnState(http.StateActive)
+	}
+
+	return st
+}
+
 func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
 	sc.serveG.check()
 
-	method := f.PseudoValue("method")
-	path := f.PseudoValue("path")
-	scheme := f.PseudoValue("scheme")
-	authority := f.PseudoValue("authority")
+	rp := requestParam{
+		method:    f.PseudoValue("method"),
+		scheme:    f.PseudoValue("scheme"),
+		authority: f.PseudoValue("authority"),
+		path:      f.PseudoValue("path"),
+	}
 
-	isConnect := method == "CONNECT"
+	isConnect := rp.method == "CONNECT"
 	if isConnect {
-		if path != "" || scheme != "" || authority == "" {
+		if rp.path != "" || rp.scheme != "" || rp.authority == "" {
 			return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 		}
-	} else if method == "" || path == "" ||
-		(scheme != "https" && scheme != "http") {
+	} else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
 		// See 8.1.2.6 Malformed Requests and Responses:
 		//
 		// Malformed requests or responses that are detected
@@ -1619,36 +1663,64 @@
 	}
 
 	bodyOpen := !f.StreamEnded()
-	if method == "HEAD" && bodyOpen {
+	if rp.method == "HEAD" && bodyOpen {
 		// HEAD requests can't have bodies
 		return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 	}
-	var tlsState *tls.ConnectionState // nil if not scheme https
 
-	if scheme == "https" {
+	rp.header = make(http.Header)
+	for _, hf := range f.RegularFields() {
+		rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+	}
+	if rp.authority == "" {
+		rp.authority = rp.header.Get("Host")
+	}
+
+	rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
+	if err != nil {
+		return nil, nil, err
+	}
+	if bodyOpen {
+		st.reqBuf = getRequestBodyBuf()
+		req.Body.(*requestBody).pipe = &pipe{
+			b: &fixedBuffer{buf: st.reqBuf},
+		}
+
+		if vv, ok := rp.header["Content-Length"]; ok {
+			req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
+		} else {
+			req.ContentLength = -1
+		}
+	}
+	return rw, req, nil
+}
+
+type requestParam struct {
+	method                  string
+	scheme, authority, path string
+	header                  http.Header
+}
+
+func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) {
+	sc.serveG.check()
+
+	var tlsState *tls.ConnectionState // nil if not scheme https
+	if rp.scheme == "https" {
 		tlsState = sc.tlsState
 	}
 
-	header := make(http.Header)
-	for _, hf := range f.RegularFields() {
-		header.Add(sc.canonicalHeader(hf.Name), hf.Value)
-	}
-
-	if authority == "" {
-		authority = header.Get("Host")
-	}
-	needsContinue := header.Get("Expect") == "100-continue"
+	needsContinue := rp.header.Get("Expect") == "100-continue"
 	if needsContinue {
-		header.Del("Expect")
+		rp.header.Del("Expect")
 	}
 	// Merge Cookie headers into one "; "-delimited value.
-	if cookies := header["Cookie"]; len(cookies) > 1 {
-		header.Set("Cookie", strings.Join(cookies, "; "))
+	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
+		rp.header.Set("Cookie", strings.Join(cookies, "; "))
 	}
 
 	// Setup Trailers
 	var trailer http.Header
-	for _, v := range header["Trailer"] {
+	for _, v := range rp.header["Trailer"] {
 		for _, key := range strings.Split(v, ",") {
 			key = http.CanonicalHeaderKey(strings.TrimSpace(key))
 			switch key {
@@ -1663,53 +1735,42 @@
 			}
 		}
 	}
-	delete(header, "Trailer")
+	delete(rp.header, "Trailer")
+
+	var url_ *url.URL
+	var requestURI string
+	if rp.method == "CONNECT" {
+		url_ = &url.URL{Host: rp.authority}
+		requestURI = rp.authority // mimic HTTP/1 server behavior
+	} else {
+		var err error
+		url_, err = url.ParseRequestURI(rp.path)
+		if err != nil {
+			return nil, nil, streamError(st.id, ErrCodeProtocol)
+		}
+		requestURI = rp.path
+	}
 
 	body := &requestBody{
 		conn:          sc,
 		stream:        st,
 		needsContinue: needsContinue,
 	}
-	var url_ *url.URL
-	var requestURI string
-	if isConnect {
-		url_ = &url.URL{Host: authority}
-		requestURI = authority // mimic HTTP/1 server behavior
-	} else {
-		var err error
-		url_, err = url.ParseRequestURI(path)
-		if err != nil {
-			return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
-		}
-		requestURI = path
-	}
 	req := &http.Request{
-		Method:     method,
+		Method:     rp.method,
 		URL:        url_,
 		RemoteAddr: sc.remoteAddrStr,
-		Header:     header,
+		Header:     rp.header,
 		RequestURI: requestURI,
 		Proto:      "HTTP/2.0",
 		ProtoMajor: 2,
 		ProtoMinor: 0,
 		TLS:        tlsState,
-		Host:       authority,
+		Host:       rp.authority,
 		Body:       body,
 		Trailer:    trailer,
 	}
 	req = requestWithContext(req, st.ctx)
-	if bodyOpen {
-		st.reqBuf = getRequestBodyBuf()
-		body.pipe = &pipe{
-			b: &fixedBuffer{buf: st.reqBuf},
-		}
-
-		if vv, ok := header["Content-Length"]; ok {
-			req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
-		} else {
-			req.ContentLength = -1
-		}
-	}
 
 	rws := responseWriterStatePool.Get().(*responseWriterState)
 	bwSave := rws.bw
@@ -2267,6 +2328,194 @@
 	responseWriterStatePool.Put(rws)
 }
 
+// Push errors.
+var (
+	ErrRecursivePush    = errors.New("http2: recursive push not allowed")
+	ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
+)
+
+// pushOptions is the internal version of http.PushOptions, which we
+// cannot include here because it's only defined in Go 1.8 and later.
+type pushOptions struct {
+	Method string
+	Header http.Header
+}
+
+func (w *responseWriter) push(target string, opts pushOptions) error {
+	st := w.rws.stream
+	sc := st.sc
+	sc.serveG.checkNotOn()
+
+	// No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream."
+	// http://tools.ietf.org/html/rfc7540#section-6.6
+	if st.isPushed() {
+		return ErrRecursivePush
+	}
+
+	// Default options.
+	if opts.Method == "" {
+		opts.Method = "GET"
+	}
+	if opts.Header == nil {
+		opts.Header = http.Header{}
+	}
+	wantScheme := "http"
+	if w.rws.req.TLS != nil {
+		wantScheme = "https"
+	}
+
+	// Validate the request.
+	u, err := url.Parse(target)
+	if err != nil {
+		return err
+	}
+	if u.Scheme == "" {
+		if !strings.HasPrefix(target, "/") {
+			return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
+		}
+		u.Scheme = wantScheme
+		u.Host = w.rws.req.Host
+	} else {
+		if u.Scheme != wantScheme {
+			return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
+		}
+		if u.Host == "" {
+			return errors.New("URL must have a host")
+		}
+	}
+	for k := range opts.Header {
+		if strings.HasPrefix(k, ":") {
+			return fmt.Errorf("promised request headers cannot include psuedo header %q", k)
+		}
+		// These headers are meaningful only if the request has a body,
+		// but PUSH_PROMISE requests cannot have a body.
+		// http://tools.ietf.org/html/rfc7540#section-8.2
+		// Also disallow Host, since the promised URL must be absolute.
+		switch strings.ToLower(k) {
+		case "content-length", "content-encoding", "trailer", "te", "expect", "host":
+			return fmt.Errorf("promised request headers cannot include %q", k)
+		}
+	}
+	if err := checkValidHTTP2RequestHeaders(opts.Header); err != nil {
+		return err
+	}
+
+	// The RFC effectively limits promised requests to GET and HEAD:
+	// "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]"
+	// http://tools.ietf.org/html/rfc7540#section-8.2
+	if opts.Method != "GET" && opts.Method != "HEAD" {
+		return fmt.Errorf("method %q must be GET or HEAD", opts.Method)
+	}
+
+	msg := startPushRequest{
+		parent: st,
+		method: opts.Method,
+		url:    u,
+		header: cloneHeader(opts.Header),
+		done:   errChanPool.Get().(chan error),
+	}
+
+	select {
+	case <-sc.doneServing:
+		return errClientDisconnected
+	case <-st.cw:
+		return errStreamClosed
+	case sc.wantStartPushCh <- msg:
+	}
+
+	select {
+	case <-sc.doneServing:
+		return errClientDisconnected
+	case <-st.cw:
+		return errStreamClosed
+	case err := <-msg.done:
+		errChanPool.Put(msg.done)
+		return err
+	}
+}
+
+type startPushRequest struct {
+	parent *stream
+	method string
+	url    *url.URL
+	header http.Header
+	done   chan error
+}
+
+func (sc *serverConn) startPush(msg startPushRequest) {
+	sc.serveG.check()
+
+	// http://tools.ietf.org/html/rfc7540#section-6.6.
+	// PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that
+	// is in either the "open" or "half-closed (remote)" state.
+	if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote {
+		// responseWriter.Push checks that the stream is peer-initiated.
+		msg.done <- errStreamClosed
+		return
+	}
+
+	// http://tools.ietf.org/html/rfc7540#section-6.6.
+	if !sc.pushEnabled {
+		msg.done <- http.ErrNotSupported
+		return
+	}
+
+	// PUSH_PROMISE frames must be sent in increasing order by stream ID, so
+	// we allocate an ID for the promised stream lazily, when the PUSH_PROMISE
+	// is written. Once the ID is allocated, we start the request handler.
+	allocatePromisedID := func() (uint32, error) {
+		sc.serveG.check()
+
+		// Check this again, just in case. Technically, we might have received
+		// an updated SETTINGS by the time we got around to writing this frame.
+		if !sc.pushEnabled {
+			return 0, http.ErrNotSupported
+		}
+		// http://tools.ietf.org/html/rfc7540#section-6.5.2.
+		if sc.curPushedStreams+1 > sc.clientMaxStreams {
+			return 0, ErrPushLimitReached
+		}
+
+		// http://tools.ietf.org/html/rfc7540#section-5.1.1.
+		// Streams initiated by the server MUST use even-numbered identifiers.
+		sc.maxPushPromiseID += 2
+		promisedID := sc.maxPushPromiseID
+
+		// http://tools.ietf.org/html/rfc7540#section-8.2.
+		// Strictly speaking, the new stream should start in "reserved (local)", then
+		// transition to "half closed (remote)" after sending the initial HEADERS, but
+		// we start in "half closed (remote)" for simplicity.
+		// See further comments at the definition of stateHalfClosedRemote.
+		promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote)
+		rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{
+			method:    msg.method,
+			scheme:    msg.url.Scheme,
+			authority: msg.url.Host,
+			path:      msg.url.RequestURI(),
+			header:    msg.header,
+		})
+		if err != nil {
+			// Should not happen, since we've already validated msg.url.
+			panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
+		}
+
+		go sc.runHandler(rw, req, sc.handler.ServeHTTP)
+		return promisedID, nil
+	}
+
+	sc.writeFrame(FrameWriteRequest{
+		write: &writePushPromise{
+			streamID:           msg.parent.id,
+			method:             msg.method,
+			url:                msg.url,
+			h:                  msg.header,
+			allocatePromisedID: allocatePromisedID,
+		},
+		stream: msg.parent,
+		done:   msg.done,
+	})
+}
+
 // foreachHeaderElement splits v according to the "#rule" construction
 // in RFC 2616 section 2.1 and calls fn for each non-empty element.
 func foreachHeaderElement(v string, fn func(string)) {
@@ -2294,16 +2543,16 @@
 	"Upgrade",
 }
 
-// checkValidHTTP2Request checks whether req is a valid HTTP/2 request,
+// checkValidHTTP2RequestHeaders checks whether h is a valid set of HTTP/2 request headers,
 // per RFC 7540 Section 8.1.2.2.
 // The returned error is reported to users.
-func checkValidHTTP2Request(req *http.Request) error {
-	for _, h := range connHeaders {
-		if _, ok := req.Header[h]; ok {
-			return fmt.Errorf("request header %q is not valid in HTTP/2", h)
+func checkValidHTTP2RequestHeaders(h http.Header) error {
+	for _, k := range connHeaders {
+		if _, ok := h[k]; ok {
+			return fmt.Errorf("request header %q is not valid in HTTP/2", k)
 		}
 	}
-	te := req.Header["Te"]
+	te := h["Te"]
 	if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
 		return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
 	}