http2: add Server support for reading trailers from clients

Updates golang/go#13557

Change-Id: I95bbb15d9abbbbc4dc6c3a22cd965d8dcef53fb8
Reviewed-on: https://go-review.googlesource.com/17891
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
diff --git a/http2/server.go b/http2/server.go
index 8d5f7cd..238c186 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -224,7 +224,7 @@
 	sc.flow.add(initialWindowSize)
 	sc.inflow.add(initialWindowSize)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
-	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
+	sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil)
 	sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
 
 	fr := NewFramer(sc.bw, c)
@@ -411,20 +411,26 @@
 // responseWriter's state field.
 type stream struct {
 	// immutable:
+	sc   *serverConn
 	id   uint32
 	body *pipe       // non-nil if expecting DATA frames
 	cw   closeWaiter // closed wait stream transitions to closed state
 
 	// owned by serverConn's serve loop:
-	bodyBytes     int64   // body bytes seen so far
-	declBodyBytes int64   // or -1 if undeclared
-	flow          flow    // limits writing from Handler to client
-	inflow        flow    // what the client is allowed to POST/etc to us
-	parent        *stream // or nil
-	weight        uint8
-	state         streamState
-	sentReset     bool // only true once detached from streams map
-	gotReset      bool // only true once detacted from streams map
+	bodyBytes        int64   // body bytes seen so far
+	declBodyBytes    int64   // or -1 if undeclared
+	flow             flow    // limits writing from Handler to client
+	inflow           flow    // what the client is allowed to POST/etc to us
+	parent           *stream // or nil
+	numTrailerValues int64
+	weight           uint8
+	state            streamState
+	sentReset        bool // only true once detached from streams map
+	gotReset         bool // only true once detached from streams map
+	gotTrailerHeader bool // HEADERS frame for trailers was seen
+
+	trailer    http.Header // accumulated trailers
+	reqTrailer http.Header // handler's Request.Trailer
 }
 
 func (sc *serverConn) Framer() *Framer  { return sc.framer }
@@ -537,6 +543,37 @@
 	}
 }
 
+func (st *stream) onNewTrailerField(f hpack.HeaderField) {
+	sc := st.sc
+	sc.serveG.check()
+	sc.vlogf("got trailer field %+v", f)
+	switch {
+	case !validHeader(f.Name):
+		// TODO: change hpack signature so this can return
+		// errors?  Or stash an error somewhere on st or sc
+		// for processHeaderBlockFragment etc to pick up and
+		// return after the hpack Write/Close.  For now just
+		// ignore.
+		return
+	case strings.HasPrefix(f.Name, ":"):
+		// TODO: same TODO as above.
+		return
+	default:
+		key := sc.canonicalHeader(f.Name)
+		if st.trailer != nil {
+			vv := append(st.trailer[key], f.Value)
+			st.trailer[key] = vv
+
+			// arbitrary; TODO: read spec about header list size limits wrt trailers
+			const tooBig = 1000
+			if len(vv) >= tooBig {
+				sc.hpackDecoder.SetEmitEnabled(false)
+			}
+
+		}
+	}
+}
+
 func (sc *serverConn) canonicalHeader(v string) string {
 	sc.serveG.check()
 	cv, ok := commonCanonHeader[v]
@@ -1249,7 +1286,7 @@
 	// with a stream error (Section 5.4.2) of type STREAM_CLOSED."
 	id := f.Header().StreamID
 	st, ok := sc.streams[id]
-	if !ok || st.state != stateOpen {
+	if !ok || st.state != stateOpen || st.gotTrailerHeader {
 		// This includes sending a RST_STREAM if the stream is
 		// in stateHalfClosedLocal (which currently means that
 		// the http.Handler returned, so it's done reading &
@@ -1283,17 +1320,38 @@
 		st.bodyBytes += int64(len(data))
 	}
 	if f.StreamEnded() {
-		if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
-			st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
-				st.declBodyBytes, st.bodyBytes))
-		} else {
-			st.body.CloseWithError(io.EOF)
-		}
-		st.state = stateHalfClosedRemote
+		st.endStream()
 	}
 	return nil
 }
 
+// endStream closes a Request.Body's pipe. It is called when a DATA
+// frame says a request body is over (or after trailers).
+func (st *stream) endStream() {
+	sc := st.sc
+	sc.serveG.check()
+
+	if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
+		st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
+			st.declBodyBytes, st.bodyBytes))
+	} else {
+		st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
+		st.body.CloseWithError(io.EOF)
+	}
+	st.state = stateHalfClosedRemote
+}
+
+// copyTrailersToHandlerRequest is run in the Handler's goroutine in
+// its Request.Body.Read just before it gets io.EOF.
+func (st *stream) copyTrailersToHandlerRequest() {
+	for k, vv := range st.trailer {
+		if _, ok := st.reqTrailer[k]; ok {
+			// Only copy it over if it was pre-declared.
+			st.reqTrailer[k] = vv
+		}
+	}
+}
+
 func (sc *serverConn) processHeaders(f *HeadersFrame) error {
 	sc.serveG.check()
 	id := f.Header().StreamID
@@ -1302,20 +1360,36 @@
 		return nil
 	}
 	// http://http2.github.io/http2-spec/#rfc.section.5.1.1
-	if id%2 != 1 || id <= sc.maxStreamID || sc.req.stream != nil {
-		// Streams initiated by a client MUST use odd-numbered
-		// stream identifiers. [...] The identifier of a newly
-		// established stream MUST be numerically greater than all
-		// streams that the initiating endpoint has opened or
-		// reserved. [...]  An endpoint that receives an unexpected
-		// stream identifier MUST respond with a connection error
-		// (Section 5.4.1) of type PROTOCOL_ERROR.
+	// Streams initiated by a client MUST use odd-numbered stream
+	// identifiers. [...] An endpoint that receives an unexpected
+	// stream identifier MUST respond with a connection error
+	// (Section 5.4.1) of type PROTOCOL_ERROR.
+	if id%2 != 1 {
 		return ConnectionError(ErrCodeProtocol)
 	}
+	// A HEADERS frame can be used to create a new stream or
+	// send a trailer for an open one. If we already have a stream
+	// open, let it process its own HEADERS frame (trailers at this
+	// point, if it's valid).
+	st := sc.streams[f.Header().StreamID]
+	if st != nil {
+		return st.processTrailerHeaders(f)
+	}
+
+	// [...] The identifier of a newly established stream MUST be
+	// numerically greater than all streams that the initiating
+	// endpoint has opened or reserved. [...]  An endpoint that
+	// receives an unexpected stream identifier MUST respond with
+	// a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
+	if id <= sc.maxStreamID || sc.req.stream != nil {
+		return ConnectionError(ErrCodeProtocol)
+	}
+
 	if id > sc.maxStreamID {
 		sc.maxStreamID = id
 	}
-	st := &stream{
+	st = &stream{
+		sc:    sc,
 		id:    id,
 		state: stateOpen,
 	}
@@ -1341,16 +1415,30 @@
 		stream: st,
 		header: make(http.Header),
 	}
+	sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
 	sc.hpackDecoder.SetEmitEnabled(true)
 	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
+func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
+	sc := st.sc
+	sc.serveG.check()
+	if st.gotTrailerHeader {
+		return ConnectionError(ErrCodeProtocol)
+	}
+	st.gotTrailerHeader = true
+	return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+}
+
 func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
 	sc.serveG.check()
 	st := sc.streams[f.Header().StreamID]
 	if st == nil || sc.curHeaderStreamID() != st.id {
 		return ConnectionError(ErrCodeProtocol)
 	}
+	if st.gotTrailerHeader {
+		return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+	}
 	return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
 }
 
@@ -1389,6 +1477,10 @@
 	if err != nil {
 		return err
 	}
+	st.reqTrailer = req.Trailer
+	if st.reqTrailer != nil {
+		st.trailer = make(http.Header)
+	}
 	st.body = req.Body.(*requestBody).pipe // may be nil
 	st.declBodyBytes = req.ContentLength
 
@@ -1402,6 +1494,24 @@
 	return nil
 }
 
+func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
+	sc := st.sc
+	sc.serveG.check()
+	sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
+	if _, err := sc.hpackDecoder.Write(frag); err != nil {
+		return ConnectionError(ErrCodeCompression)
+	}
+	if !end {
+		return nil
+	}
+	err := sc.hpackDecoder.Close()
+	st.endStream()
+	if err != nil {
+		return ConnectionError(ErrCodeCompression)
+	}
+	return nil
+}
+
 func (sc *serverConn) processPriority(f *PriorityFrame) error {
 	adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam)
 	return nil
@@ -1489,6 +1599,26 @@
 	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
 		rp.header.Set("Cookie", strings.Join(cookies, "; "))
 	}
+
+	// Set up trailers.
+	var trailer http.Header
+	for _, v := range rp.header["Trailer"] {
+		for _, key := range strings.Split(v, ",") {
+			key = http.CanonicalHeaderKey(strings.TrimSpace(key))
+			switch key {
+			case "Transfer-Encoding", "Trailer", "Content-Length":
+				// Bogus. (copy of http1 rules)
+				// Ignore.
+			default:
+				if trailer == nil {
+					trailer = make(http.Header)
+				}
+				trailer[key] = nil
+			}
+		}
+	}
+	delete(rp.header, "Trailer")
+
 	body := &requestBody{
 		conn:          sc,
 		stream:        rp.stream,
@@ -1512,10 +1642,11 @@
 		TLS:        tlsState,
 		Host:       authority,
 		Body:       body,
+		Trailer:    trailer,
 	}
 	if bodyOpen {
 		body.pipe = &pipe{
-			b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX
+			b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage
 		}
 
 		if vv, ok := rp.header["Content-Length"]; ok {