Refactor DATA frame writing from handlers in preparation for more flow-control work.
diff --git a/server.go b/server.go
index 405593f..253520a 100644
--- a/server.go
+++ b/server.go
@@ -369,7 +369,11 @@
func (sc *serverConn) writeFrames() {
sc.writeG = newGoroutineLock()
for wm := range sc.writeFrameCh {
- err := wm.write(sc, wm.v)
+ var streamID uint32
+ if wm.stream != nil {
+ streamID = wm.stream.id
+ }
+ err := wm.write(sc, streamID, wm.v)
if ch := wm.done; ch != nil {
select {
case ch <- err:
@@ -381,7 +385,7 @@
}
}
-func (sc *serverConn) flushFrameWriter(_ interface{}) error {
+func (sc *serverConn) flushFrameWriter(uint32, interface{}) error {
sc.writeG.check()
return sc.bw.Flush() // may block on the network
}
@@ -453,7 +457,7 @@
}
}
-func (sc *serverConn) sendInitialSettings(_ interface{}) error {
+func (sc *serverConn) sendInitialSettings(uint32, interface{}) error {
sc.writeG.check()
return sc.framer.WriteSettings(
Setting{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
@@ -491,8 +495,42 @@
}
}
-// should be called from non-serve() goroutines, otherwise the ends may deadlock
-// the serve loop. (it's only buffered a little bit).
+// writeData writes the data described in req to stream.id.
+//
+// The provided ch is used to avoid allocating new channels for each
+// write operation. It's expected that the caller reuses req and ch
+// over time.
+func (sc *serverConn) writeData(stream *stream, req *dataWriteRequest, ch chan error) error {
+ sc.serveG.checkNotOn() // otherwise could deadlock in sc.writeFrame
+
+ // TODO: wait for flow control tokens. instead of writing a
+ // frame directly, add a new "write data" channel to the serve
+ // loop and modify the frame scheduler there to write chunks
+ // of req as tokens allow. Don't necessarily write it all at
+ // once in one frame.
+ sc.writeFrame(frameWriteMsg{
+ write: (*serverConn).writeDataFrame,
+ cost: uint32(len(req.p)),
+ stream: stream,
+ endStream: req.end,
+ v: req,
+ done: ch,
+ })
+ select {
+ case err := <-ch:
+ return err
+ case <-sc.doneServing:
+ return errClientDisconnected
+ }
+}
+
+// writeFrame sends wm to sc.wantWriteFrameCh, but aborts if the
+// connection has gone away.
+//
+// This must not be run from the serve goroutine itself, else it might
+// deadlock writing to sc.wantWriteFrameCh (which is only mildly
+// buffered and is read by serve itself). If you're on the serve
+// goroutine, call enqueueFrameWrite instead.
func (sc *serverConn) writeFrame(wm frameWriteMsg) {
sc.serveG.checkNotOn() // NOT
select {
@@ -502,6 +540,12 @@
}
}
+// enqueueFrameWrite either sends wm to the writeFrames goroutine, or
+// enqueues it for the future (with no pushback; the serve goroutine
+// never blocks!), for sending when the currently-being-written frame
+// is done writing.
+//
+// If you're not on the serve goroutine, use writeFrame instead.
func (sc *serverConn) enqueueFrameWrite(wm frameWriteMsg) {
sc.serveG.check()
// Fast path for common case:
@@ -605,6 +649,25 @@
return
}
+ // TODO:
+ // -- prioritize all non-DATA frames first. they're not flow controlled anyway and
+ // they're generally more important.
+ // -- for all DATA frames that are enqueued (and we should enqueue []byte instead of FRAMES),
+ // go over each (in priority order, as determined by the whole priority tree chaos),
+ // and decide which we have tokens for, and how many tokens.
+
+ // Writing on stream X requires that we have tokens on the
+ // stream 0 (the conn-as-a-whole stream) as well as stream X.
+
+ // So: find the highest priority stream X, then see: do we
+ // have tokens for X? Let's say we have N_X tokens. Then we should
+ // write MIN(N_X, TOKENS(conn-wide-tokens)).
+ //
+ // Any tokens left over? Repeat. Well, not really... the
+ // repeat will happen via the next call to
+ // scheduleFrameWrite. So keep a HEAP (priqueue) of which
+ // streams to write to.
+
// TODO: proper scheduler
wm := sc.writeQueue[0]
// shift it all down. kinda lame. will be removed later anyway.
@@ -646,7 +709,7 @@
code ErrCode
}
-func (sc *serverConn) writeGoAwayFrame(v interface{}) error {
+func (sc *serverConn) writeGoAwayFrame(_ uint32, v interface{}) error {
sc.writeG.check()
p := v.(*goAwayParams)
err := sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
@@ -672,7 +735,7 @@
sc.closeStream(st, se)
}
-func (sc *serverConn) writeRSTStreamFrame(v interface{}) error {
+func (sc *serverConn) writeRSTStreamFrame(streamID uint32, v interface{}) error {
sc.writeG.check()
se := v.(*StreamError)
return sc.framer.WriteRSTStream(se.StreamID, se.Code)
@@ -796,7 +859,7 @@
return nil
}
-func (sc *serverConn) writePingAck(v interface{}) error {
+func (sc *serverConn) writePingAck(_ uint32, v interface{}) error {
sc.writeG.check()
pf := v.(*PingFrame) // contains the data we need to write back
return sc.framer.WritePing(true, pf.Data)
@@ -871,7 +934,7 @@
return nil
}
-func (sc *serverConn) writeSettingsAck(_ interface{}) error {
+func (sc *serverConn) writeSettingsAck(uint32, interface{}) error {
return sc.framer.WriteSettingsAck()
}
@@ -1145,7 +1208,7 @@
rws.stream = rp.stream
rws.req = req
rws.body = body
- rws.chunkWrittenCh = make(chan error, 1)
+ rws.frameWriteCh = make(chan error, 1)
rw := &responseWriter{rws: rws}
return rw, req, nil
@@ -1170,7 +1233,7 @@
type frameWriteMsg struct {
// write runs on the writeFrames goroutine.
- write func(sc *serverConn, v interface{}) error
+ write func(sc *serverConn, streamID uint32, v interface{}) error
v interface{} // passed to write
cost uint32 // number of flow control bytes required
@@ -1196,7 +1259,7 @@
// called from handler goroutines.
// h may be nil.
-func (sc *serverConn) writeHeaders(req headerWriteReq) {
+func (sc *serverConn) writeHeaders(req headerWriteReq, tempCh chan error) {
sc.serveG.checkNotOn() // NOT on
var errc chan error
if req.h != nil {
@@ -1204,7 +1267,7 @@
// waiting for this frame to be written, so an http.Flush mid-handler
// writes out the correct value of keys, before a handler later potentially
// mutates it.
- errc = make(chan error, 1)
+ errc = tempCh
}
sc.writeFrame(frameWriteMsg{
write: (*serverConn).writeHeadersFrame,
@@ -1224,7 +1287,7 @@
}
}
-func (sc *serverConn) writeHeadersFrame(v interface{}) error {
+func (sc *serverConn) writeHeadersFrame(streamID uint32, v interface{}) error {
sc.writeG.check()
req := v.(headerWriteReq)
@@ -1265,33 +1328,30 @@
sc.serveG.checkNotOn() // NOT
sc.writeFrame(frameWriteMsg{
write: (*serverConn).write100ContinueHeadersFrame,
- v: st,
stream: st,
})
}
-func (sc *serverConn) write100ContinueHeadersFrame(v interface{}) error {
+func (sc *serverConn) write100ContinueHeadersFrame(streamID uint32, _ interface{}) error {
sc.writeG.check()
- st := v.(*stream)
sc.headerWriteBuf.Reset()
sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
return sc.framer.WriteHeaders(HeadersFrameParam{
- StreamID: st.id,
+ StreamID: streamID,
BlockFragment: sc.headerWriteBuf.Bytes(),
EndStream: false,
EndHeaders: true,
})
}
-func (sc *serverConn) writeDataFrame(v interface{}) error {
+func (sc *serverConn) writeDataFrame(streamID uint32, v interface{}) error {
sc.writeG.check()
- rws := v.(*responseWriterState)
- return sc.framer.WriteData(rws.stream.id, rws.curChunkIsFinal, rws.curChunk)
+ req := v.(*dataWriteRequest)
+ return sc.framer.WriteData(streamID, req.end, req.p)
}
type windowUpdateReq struct {
- stream *stream
- n uint32
+ n uint32
}
// called from handler goroutines
@@ -1303,7 +1363,7 @@
for n >= maxUint32 {
sc.writeFrame(frameWriteMsg{
write: (*serverConn).sendWindowUpdateInLoop,
- v: windowUpdateReq{st, maxUint32},
+ v: windowUpdateReq{maxUint32},
stream: st,
})
n -= maxUint32
@@ -1311,19 +1371,19 @@
if n > 0 {
sc.writeFrame(frameWriteMsg{
write: (*serverConn).sendWindowUpdateInLoop,
- v: windowUpdateReq{st, uint32(n)},
+ v: windowUpdateReq{uint32(n)},
stream: st,
})
}
}
-func (sc *serverConn) sendWindowUpdateInLoop(v interface{}) error {
+func (sc *serverConn) sendWindowUpdateInLoop(streamID uint32, v interface{}) error {
sc.writeG.check()
wu := v.(windowUpdateReq)
if err := sc.framer.WriteWindowUpdate(0, wu.n); err != nil {
return err
}
- if err := sc.framer.WriteWindowUpdate(wu.stream.id, wu.n); err != nil {
+ if err := sc.framer.WriteWindowUpdate(streamID, wu.n); err != nil {
return err
}
return nil
@@ -1397,15 +1457,24 @@
status int // status code passed to WriteHeader
sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished
-
- curChunk []byte // current chunk we're writing
- curChunkIsFinal bool
- chunkWrittenCh chan error
+ curWrite dataWriteRequest
+ frameWriteCh chan error // re-used whenever we need to block on a frame being written
closeNotifierMu sync.Mutex // guards closeNotifierCh
closeNotifierCh chan bool // nil until first used
}
+func (rws *responseWriterState) writeData(p []byte, end bool) error {
+ rws.curWrite.p = p
+ rws.curWrite.end = end
+ return rws.stream.conn.writeData(rws.stream, &rws.curWrite, rws.frameWriteCh)
+}
+
+type dataWriteRequest struct {
+ p []byte
+ end bool
+}
+
type chunkWriter struct{ rws *responseWriterState }
func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) }
@@ -1437,22 +1506,14 @@
endStream: endStream,
contentType: ctype,
contentLength: clen,
- })
+ }, rws.frameWriteCh)
if endStream {
return
}
}
if len(p) == 0 {
if rws.handlerDone {
- rws.curChunk = nil
- rws.curChunkIsFinal = true
- rws.stream.conn.writeFrame(frameWriteMsg{
- write: (*serverConn).writeDataFrame,
- cost: 0,
- stream: rws.stream,
- endStream: true,
- v: rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
- })
+ err = rws.writeData(nil, true)
}
return
}
@@ -1462,25 +1523,8 @@
chunk = chunk[:handlerChunkWriteSize]
}
p = p[len(chunk):]
- rws.curChunk = chunk
- rws.curChunkIsFinal = rws.handlerDone && len(p) == 0
-
- // TODO: await flow control tokens for both stream and conn
- rws.stream.conn.writeFrame(frameWriteMsg{
- write: (*serverConn).writeDataFrame,
- cost: uint32(len(chunk)),
- stream: rws.stream,
- endStream: rws.curChunkIsFinal,
- done: rws.chunkWrittenCh,
- v: rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
- })
- // Block until it's written, or if the client disconnects.
- select {
- case err = <-rws.chunkWrittenCh:
- case <-rws.stream.conn.doneServing:
- // Client disconnected.
- err = errClientDisconnected
- }
+ isFinal := rws.handlerDone && len(p) == 0
+ err = rws.writeData(chunk, isFinal)
if err != nil {
break
}