http2: add Server support for reading trailers from clients
Updates golang/go#13557
Change-Id: I95bbb15d9abbbbc4dc6c3a22cd965d8dcef53fb8
Reviewed-on: https://go-review.googlesource.com/17891
Reviewed-by: Blake Mizerany <blake.mizerany@gmail.com>
diff --git a/http2/server.go b/http2/server.go
index 8d5f7cd..238c186 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -224,7 +224,7 @@
sc.flow.add(initialWindowSize)
sc.inflow.add(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
- sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField)
+ sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, nil)
sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
fr := NewFramer(sc.bw, c)
@@ -411,20 +411,26 @@
// responseWriter's state field.
type stream struct {
// immutable:
+ sc *serverConn // connection that created this stream
id uint32
body *pipe // non-nil if expecting DATA frames
cw closeWaiter // closed wait stream transitions to closed state
// owned by serverConn's serve loop:
- bodyBytes int64 // body bytes seen so far
- declBodyBytes int64 // or -1 if undeclared
- flow flow // limits writing from Handler to client
- inflow flow // what the client is allowed to POST/etc to us
- parent *stream // or nil
- weight uint8
- state streamState
- sentReset bool // only true once detached from streams map
- gotReset bool // only true once detacted from streams map
+ bodyBytes int64 // body bytes seen so far
+ declBodyBytes int64 // or -1 if undeclared
+ flow flow // limits writing from Handler to client
+ inflow flow // what the client is allowed to POST/etc to us
+ parent *stream // or nil
+ numTrailerValues int64 // NOTE(review): not referenced in the visible hunks; presumably a count of trailer values seen - confirm
+ weight uint8
+ state streamState
+ sentReset bool // only true once detached from streams map
+ gotReset bool // only true once detached from streams map
+ gotTrailerHeader bool // HEADERS frame for trailers was seen
+
+ trailer http.Header // accumulated trailers; nil unless the request declared trailers
+ reqTrailer http.Header // handler's Request.Trailer
}
func (sc *serverConn) Framer() *Framer { return sc.framer }
@@ -537,6 +543,37 @@
}
}
+// onNewTrailerField receives one decoded header field of a trailer
+// block. It is installed as the hpack decoder's emit function by
+// processTrailerHeaderBlockFragment.
+//
+// Fields with an invalid name and pseudo-header fields (names starting
+// with ":") are dropped without reporting an error. Every other field
+// is appended to st.trailer under its canonical key, but only when
+// st.trailer is non-nil.
+func (st *stream) onNewTrailerField(f hpack.HeaderField) {
+ sc := st.sc
+ sc.serveG.check()
+ sc.vlogf("got trailer field %+v", f)
+ switch {
+ case !validHeader(f.Name):
+ // TODO: change hpack signature so this can return
+ // errors? Or stash an error somewhere on st or sc
+ // for processHeaderBlockFragment etc to pick up and
+ // return after the hpack Write/Close. For now just
+ // ignore.
+ return
+ case strings.HasPrefix(f.Name, ":"):
+ // TODO: same TODO as above.
+ return
+ default:
+ key := sc.canonicalHeader(f.Name)
+ if st.trailer != nil {
+ vv := append(st.trailer[key], f.Value)
+ st.trailer[key] = vv
+
+ // arbitrary; TODO: read spec about header list size limits wrt trailers
+ // Note that the limit is applied to the number of values
+ // stored under a single key, not to the trailer block as a
+ // whole. Presumably disabling emit stops further callbacks
+ // for this block - confirm against the hpack package.
+ const tooBig = 1000
+ if len(vv) >= tooBig {
+ sc.hpackDecoder.SetEmitEnabled(false)
+ }
+
+ }
+ }
+}
+
func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check()
cv, ok := commonCanonHeader[v]
@@ -1249,7 +1286,7 @@
// with a stream error (Section 5.4.2) of type STREAM_CLOSED."
id := f.Header().StreamID
st, ok := sc.streams[id]
- if !ok || st.state != stateOpen {
+ if !ok || st.state != stateOpen || st.gotTrailerHeader {
// This includes sending a RST_STREAM if the stream is
// in stateHalfClosedLocal (which currently means that
// the http.Handler returned, so it's done reading &
@@ -1283,17 +1320,38 @@
st.bodyBytes += int64(len(data))
}
if f.StreamEnded() {
- if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
- st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
- st.declBodyBytes, st.bodyBytes))
- } else {
- st.body.CloseWithError(io.EOF)
- }
- st.state = stateHalfClosedRemote
+ st.endStream()
}
return nil
}
+// endStream closes a Request.Body's pipe, if the request has one, and
+// moves the stream to stateHalfClosedRemote. It is called when a DATA
+// frame says a request body is over (or after trailers).
+//
+// st.body is nil when the request was not expecting DATA frames, yet a
+// trailer HEADERS frame can still be routed here for such a stream, so
+// the pipe is only touched when it exists.
+func (st *stream) endStream() {
+	sc := st.sc
+	sc.serveG.check()
+
+	if st.body != nil {
+		if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
+			// The declared Content-Length does not match what was
+			// received: surface that to the reader and skip the
+			// trailer copy.
+			st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
+				st.declBodyBytes, st.bodyBytes))
+		} else {
+			// copyTrailersToHandlerRequest is handed to the pipe so
+			// the trailers are published before the reader sees EOF.
+			st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
+			st.body.CloseWithError(io.EOF)
+		}
+	}
+	st.state = stateHalfClosedRemote
+}
+
+// copyTrailersToHandlerRequest is run in the Handler's goroutine in
+// its Request.Body.Read just before it gets io.EOF.
+//
+// It copies the accumulated trailer values from st.trailer into
+// st.reqTrailer, but only for keys that already exist in st.reqTrailer.
+func (st *stream) copyTrailersToHandlerRequest() {
+ for k, vv := range st.trailer {
+ if _, ok := st.reqTrailer[k]; ok {
+ // Only copy it over if it was pre-declared.
+ st.reqTrailer[k] = vv
+ }
+ }
+}
+
func (sc *serverConn) processHeaders(f *HeadersFrame) error {
sc.serveG.check()
id := f.Header().StreamID
@@ -1302,20 +1360,36 @@
return nil
}
// http://http2.github.io/http2-spec/#rfc.section.5.1.1
- if id%2 != 1 || id <= sc.maxStreamID || sc.req.stream != nil {
- // Streams initiated by a client MUST use odd-numbered
- // stream identifiers. [...] The identifier of a newly
- // established stream MUST be numerically greater than all
- // streams that the initiating endpoint has opened or
- // reserved. [...] An endpoint that receives an unexpected
- // stream identifier MUST respond with a connection error
- // (Section 5.4.1) of type PROTOCOL_ERROR.
+ // Streams initiated by a client MUST use odd-numbered stream
+ // identifiers. [...] An endpoint that receives an unexpected
+ // stream identifier MUST respond with a connection error
+ // (Section 5.4.1) of type PROTOCOL_ERROR.
+ if id%2 != 1 {
return ConnectionError(ErrCodeProtocol)
}
+ // A HEADERS frame can be used to create a new stream or
+ // send a trailer for an open one. If we already have a stream
+ // open, let it process its own HEADERS frame (trailers at this
+ // point, if it's valid).
+ st := sc.streams[f.Header().StreamID]
+ if st != nil {
+ return st.processTrailerHeaders(f)
+ }
+
+ // [...] The identifier of a newly established stream MUST be
+ // numerically greater than all streams that the initiating
+ // endpoint has opened or reserved. [...] An endpoint that
+ // receives an unexpected stream identifier MUST respond with
+ // a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
+ if id <= sc.maxStreamID || sc.req.stream != nil {
+ return ConnectionError(ErrCodeProtocol)
+ }
+
if id > sc.maxStreamID {
sc.maxStreamID = id
}
- st := &stream{
+ st = &stream{
+ sc: sc,
id: id,
state: stateOpen,
}
@@ -1341,16 +1415,30 @@
stream: st,
header: make(http.Header),
}
+ sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
sc.hpackDecoder.SetEmitEnabled(true)
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}
+// processTrailerHeaders handles a HEADERS frame that arrives for a
+// stream which already exists in sc.streams; such a frame is treated
+// as the start of the request's trailer block.
+//
+// Only one trailer HEADERS frame is accepted per stream: a second one
+// is a connection error of type PROTOCOL_ERROR.
+//
+// NOTE(review): neither the stream's state nor the frame's END_STREAM
+// flag is checked here - confirm that callers or later code cover it.
+func (st *stream) processTrailerHeaders(f *HeadersFrame) error {
+ sc := st.sc
+ sc.serveG.check()
+ if st.gotTrailerHeader {
+ return ConnectionError(ErrCodeProtocol)
+ }
+ st.gotTrailerHeader = true
+ return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+}
+
+// processContinuation handles a CONTINUATION frame. The frame's stream
+// must exist in sc.streams and its ID must equal
+// sc.curHeaderStreamID(); otherwise it is a connection error of type
+// PROTOCOL_ERROR.
func (sc *serverConn) processContinuation(f *ContinuationFrame) error {
sc.serveG.check()
st := sc.streams[f.Header().StreamID]
if st == nil || sc.curHeaderStreamID() != st.id {
return ConnectionError(ErrCodeProtocol)
}
+ // A stream that has already seen its trailer HEADERS frame is
+ // continuing a trailer block, not a request header block.
+ if st.gotTrailerHeader {
+ return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+ }
return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
}
@@ -1389,6 +1477,10 @@
if err != nil {
return err
}
+ st.reqTrailer = req.Trailer
+ if st.reqTrailer != nil {
+ st.trailer = make(http.Header)
+ }
st.body = req.Body.(*requestBody).pipe // may be nil
st.declBodyBytes = req.ContentLength
@@ -1402,6 +1494,24 @@
return nil
}
+// processTrailerHeaderBlockFragment feeds one fragment of a trailer
+// header block to the connection's hpack decoder, with
+// st.onNewTrailerField installed as the emit function.
+//
+// frag is the raw header block fragment from a HEADERS or CONTINUATION
+// frame. end reports whether this fragment completes the block; when
+// it does, the decoder is closed and the stream is ended. A decoder
+// Write or Close failure is returned as a connection error of type
+// COMPRESSION_ERROR.
+func (st *stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
+ sc := st.sc
+ sc.serveG.check()
+ sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
+ if _, err := sc.hpackDecoder.Write(frag); err != nil {
+ return ConnectionError(ErrCodeCompression)
+ }
+ if !end {
+ return nil
+ }
+ err := sc.hpackDecoder.Close()
+ // endStream runs before the Close error is examined, so the stream
+ // is ended even when Close fails.
+ st.endStream()
+ if err != nil {
+ return ConnectionError(ErrCodeCompression)
+ }
+ return nil
+}
+
func (sc *serverConn) processPriority(f *PriorityFrame) error {
adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam)
return nil
@@ -1489,6 +1599,26 @@
if cookies := rp.header["Cookie"]; len(cookies) > 1 {
rp.header.Set("Cookie", strings.Join(cookies, "; "))
}
+
+ // Setup Trailers
+ var trailer http.Header
+ for _, v := range rp.header["Trailer"] {
+ for _, key := range strings.Split(v, ",") {
+ key = http.CanonicalHeaderKey(strings.TrimSpace(key))
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ // Bogus. (copy of http1 rules)
+ // Ignore.
+ default:
+ if trailer == nil {
+ trailer = make(http.Header)
+ }
+ trailer[key] = nil
+ }
+ }
+ }
+ delete(rp.header, "Trailer")
+
body := &requestBody{
conn: sc,
stream: rp.stream,
@@ -1512,10 +1642,11 @@
TLS: tlsState,
Host: authority,
Body: body,
+ Trailer: trailer,
}
if bodyOpen {
body.pipe = &pipe{
- b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX
+ b: &fixedBuffer{buf: make([]byte, initialWindowSize)}, // TODO: garbage
}
if vv, ok := rp.header["Content-Length"]; ok {