Refactor frame writing in preparation for the write scheduler and client support.
diff --git a/server.go b/server.go
index 663c693..921eb6c 100644
--- a/server.go
+++ b/server.go
@@ -236,8 +236,7 @@
shutdownTimerCh <-chan time.Time // nil until used
shutdownTimer *time.Timer // nil until used
- // Owned by the writeFrameAsync goroutine; use writeG.check():
- writeG goroutineLock // used to verify things running on writeFrameAsync
+ // Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
hpackEncoder *hpack.Encoder
}
@@ -279,6 +278,13 @@
gotReset bool // only true once detached from streams map
}
+func (sc *serverConn) Framer() *Framer { return sc.framer }
+func (sc *serverConn) CloseConn() error { return sc.conn.Close() }
+func (sc *serverConn) Flush() error { return sc.bw.Flush() }
+func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
+ return sc.hpackEncoder, &sc.headerWriteBuf
+}
+
func (sc *serverConn) state(streamID uint32) streamState {
sc.serveG.check()
// http://http2.github.io/http2-spec/#rfc.section.5.1
@@ -418,12 +424,7 @@
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) {
- sc.writeG = newGoroutineLock()
- var streamID uint32
- if wm.stream != nil {
- streamID = wm.stream.id
- }
- err := wm.write(sc, streamID, wm.v)
+ err := wm.write(sc, wm.v)
if ch := wm.done; ch != nil {
select {
case ch <- err:
@@ -434,11 +435,6 @@
sc.wroteFrameCh <- struct{}{} // tickle frame selection scheduler
}
-func (sc *serverConn) flushFrameWriter(uint32, interface{}) error {
- sc.writeG.check()
- return sc.bw.Flush() // may block on the network
-}
-
func (sc *serverConn) closeAllStreamsOnConnClose() {
sc.serveG.check()
for _, st := range sc.streams {
@@ -462,7 +458,14 @@
sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
- sc.writeFrame(frameWriteMsg{write: (*serverConn).sendInitialSettings})
+ sc.writeFrame(frameWriteMsg{
+ write: writeSettings,
+ v: []Setting{
+ {SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
+ {SettingMaxConcurrentStreams, sc.advMaxStreams},
+ /* TODO: more actual settings */
+ },
+ })
if err := sc.readPreface(); err != nil {
sc.condlogf(err, "error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
@@ -502,15 +505,6 @@
}
}
-func (sc *serverConn) sendInitialSettings(uint32, interface{}) error {
- sc.writeG.check()
- return sc.framer.WriteSettings(
- Setting{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
- Setting{SettingMaxConcurrentStreams, sc.advMaxStreams},
- /* TODO: more actual settings */
- )
-}
-
// readPreface reads the ClientPreface greeting from the peer
// or returns an error on timeout or an invalid greeting.
func (sc *serverConn) readPreface() error {
@@ -554,7 +548,7 @@
func (sc *serverConn) writeData(stream *stream, data *dataWriteParams, ch chan error) error {
sc.serveG.checkNotOn() // NOT on; otherwise could deadlock in sc.writeFrame
sc.writeFrameFromHandler(frameWriteMsg{
- write: (*serverConn).writeDataFrame,
+ write: writeDataFrame,
cost: uint32(len(data.p)),
stream: stream,
endStream: data.end,
@@ -661,7 +655,7 @@
if sc.needToSendGoAway {
sc.needToSendGoAway = false
sc.startFrameWrite(frameWriteMsg{
- write: (*serverConn).writeGoAwayFrame,
+ write: writeGoAwayFrame,
v: &goAwayParams{
maxStreamID: sc.maxStreamID,
code: sc.goAwayCode,
@@ -670,7 +664,7 @@
return
}
if sc.writeSched.empty() && sc.needsFrameFlush {
- sc.startFrameWrite(frameWriteMsg{write: (*serverConn).flushFrameWriter})
+ sc.startFrameWrite(frameWriteMsg{write: flushFrameWriter})
sc.needsFrameFlush = false // after startFrameWrite, since it sets this true
return
}
@@ -680,7 +674,7 @@
}
if sc.needToSendSettingsAck {
sc.needToSendSettingsAck = false
- sc.startFrameWrite(frameWriteMsg{write: (*serverConn).writeSettingsAck})
+ sc.startFrameWrite(frameWriteMsg{write: writeSettingsAck})
return
}
if sc.writeSched.empty() {
@@ -716,18 +710,6 @@
sc.shutdownTimerCh = sc.shutdownTimer.C
}
-func (sc *serverConn) writeGoAwayFrame(_ uint32, v interface{}) error {
- sc.writeG.check()
- p := v.(*goAwayParams)
- err := sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
- if p.code != 0 {
- sc.bw.Flush() // ignore error: we're hanging up on them anyway
- time.Sleep(50 * time.Millisecond)
- sc.conn.Close()
- }
- return err
-}
-
func (sc *serverConn) resetStream(se StreamError) {
sc.serveG.check()
st, ok := sc.streams[se.StreamID]
@@ -735,19 +717,13 @@
panic("internal package error; resetStream called on non-existent stream")
}
sc.writeFrame(frameWriteMsg{
- write: (*serverConn).writeRSTStreamFrame,
+ write: writeRSTStreamFrame,
v: &se,
})
st.sentReset = true
sc.closeStream(st, se)
}
-func (sc *serverConn) writeRSTStreamFrame(streamID uint32, v interface{}) error {
- sc.writeG.check()
- se := v.(*StreamError)
- return sc.framer.WriteRSTStream(se.StreamID, se.Code)
-}
-
// curHeaderStreamID returns the stream ID of the header block we're
// currently in the middle of reading. If this returns non-zero, the
// next frame must be a CONTINUATION with this stream id.
@@ -871,18 +847,12 @@
return ConnectionError(ErrCodeProtocol)
}
sc.writeFrame(frameWriteMsg{
- write: (*serverConn).writePingAck,
+ write: writePingAck,
v: f,
})
return nil
}
-func (sc *serverConn) writePingAck(_ uint32, v interface{}) error {
- sc.writeG.check()
- pf := v.(*PingFrame) // contains the data we need to write back
- return sc.framer.WritePing(true, pf.Data)
-}
-
func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
sc.serveG.check()
switch {
@@ -961,11 +931,6 @@
return nil
}
-func (sc *serverConn) writeSettingsAck(uint32, interface{}) error {
- sc.writeG.check()
- return sc.framer.WriteSettingsAck()
-}
-
func (sc *serverConn) processSetting(s Setting) error {
sc.serveG.check()
if err := s.Valid(); err != nil {
@@ -1263,21 +1228,6 @@
sc.handler.ServeHTTP(rw, req)
}
-type frameWriteMsg struct {
- // write runs on the writeFrameAsync goroutine.
- write func(sc *serverConn, streamID uint32, v interface{}) error
-
- v interface{} // passed to write
- cost uint32 // number of flow control bytes required
- stream *stream // used for prioritization
- endStream bool // streamID is being closed locally
-
- // done, if non-nil, must be a buffered channel with space for
- // 1 message and is sent the return value from write (or an
- // earlier error) when the frame has been written.
- done chan error
-}
-
// headerWriteReq is a request to write an HTTP response header from a server Handler.
type headerWriteReq struct {
stream *stream
@@ -1302,7 +1252,7 @@
errc = tempCh
}
sc.writeFrameFromHandler(frameWriteMsg{
- write: (*serverConn).writeHeadersFrame,
+ write: writeHeadersFrame,
v: req,
stream: req.stream,
done: errc,
@@ -1319,91 +1269,16 @@
}
}
-func (sc *serverConn) writeHeadersFrame(streamID uint32, v interface{}) error {
- sc.writeG.check()
- req := v.(headerWriteReq)
-
- sc.headerWriteBuf.Reset()
- sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(req.httpResCode)})
- for k, vv := range req.h {
- k = lowerHeader(k)
- for _, v := range vv {
- // TODO: more of "8.1.2.2 Connection-Specific Header Fields"
- if k == "transfer-encoding" && v != "trailers" {
- continue
- }
- sc.hpackEncoder.WriteField(hpack.HeaderField{Name: k, Value: v})
- }
- }
- if req.contentType != "" {
- sc.hpackEncoder.WriteField(hpack.HeaderField{Name: "content-type", Value: req.contentType})
- }
- if req.contentLength != "" {
- sc.hpackEncoder.WriteField(hpack.HeaderField{Name: "content-length", Value: req.contentLength})
- }
-
- headerBlock := sc.headerWriteBuf.Bytes()
- if len(headerBlock) == 0 {
- panic("unexpected empty hpack")
- }
- first := true
- for len(headerBlock) > 0 {
- frag := headerBlock
- if len(frag) > int(sc.maxWriteFrameSize) {
- frag = frag[:sc.maxWriteFrameSize]
- }
- headerBlock = headerBlock[len(frag):]
- endHeaders := len(headerBlock) == 0
- var err error
- if first {
- first = false
- err = sc.framer.WriteHeaders(HeadersFrameParam{
- StreamID: req.stream.id,
- BlockFragment: frag,
- EndStream: req.endStream,
- EndHeaders: endHeaders,
- })
- } else {
- err = sc.framer.WriteContinuation(req.stream.id, endHeaders, frag)
- }
- if err != nil {
- return err
- }
- }
- return nil
-}
-
// called from handler goroutines.
func (sc *serverConn) write100ContinueHeaders(st *stream) {
sc.serveG.checkNotOn() // NOT
sc.writeFrameFromHandler(frameWriteMsg{
- write: (*serverConn).write100ContinueHeadersFrame,
+ write: write100ContinueHeadersFrame,
+ v: st,
stream: st,
})
}
-func (sc *serverConn) write100ContinueHeadersFrame(streamID uint32, _ interface{}) error {
- sc.writeG.check()
- sc.headerWriteBuf.Reset()
- sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
- return sc.framer.WriteHeaders(HeadersFrameParam{
- StreamID: streamID,
- BlockFragment: sc.headerWriteBuf.Bytes(),
- EndStream: false,
- EndHeaders: true,
- })
-}
-
-func (sc *serverConn) writeDataFrame(streamID uint32, v interface{}) error {
- sc.writeG.check()
- req := v.(*dataWriteParams)
- return sc.framer.WriteData(streamID, req.end, req.p)
-}
-
-type windowUpdateReq struct {
- n uint32
-}
-
// called from handler goroutines
func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
sc.serveG.checkNotOn() // NOT
@@ -1413,33 +1288,21 @@
const maxUint32 = 2147483647
for n >= maxUint32 {
sc.writeFrameFromHandler(frameWriteMsg{
- write: (*serverConn).sendWindowUpdateInLoop,
- v: windowUpdateReq{maxUint32},
+ write: writeWindowUpdate,
+ v: windowUpdateReq{streamID: st.id, n: maxUint32},
stream: st,
})
n -= maxUint32
}
if n > 0 {
sc.writeFrameFromHandler(frameWriteMsg{
- write: (*serverConn).sendWindowUpdateInLoop,
- v: windowUpdateReq{uint32(n)},
+ write: writeWindowUpdate,
+ v: windowUpdateReq{streamID: st.id, n: uint32(n)},
stream: st,
})
}
}
-func (sc *serverConn) sendWindowUpdateInLoop(streamID uint32, v interface{}) error {
- sc.writeG.check()
- wu := v.(windowUpdateReq)
- if err := sc.framer.WriteWindowUpdate(0, wu.n); err != nil {
- return err
- }
- if err := sc.framer.WriteWindowUpdate(streamID, wu.n); err != nil {
- return err
- }
- return nil
-}
-
type requestBody struct {
stream *stream
closed bool
@@ -1511,6 +1374,7 @@
}
func (rws *responseWriterState) writeData(p []byte, end bool) error {
+ rws.curWrite.streamID = rws.stream.id
rws.curWrite.p = p
rws.curWrite.end = end
return rws.stream.conn.writeData(rws.stream, &rws.curWrite, rws.frameWriteCh)