Fix stream state transitions and remove streams from the map when their state goes to CLOSED.
diff --git a/server.go b/server.go
index 084225e..4f30990 100644
--- a/server.go
+++ b/server.go
@@ -427,13 +427,37 @@
sc.serveG.check()
// Fast path for common case:
if !sc.writingFrame {
- sc.writingFrame = true
- sc.writeFrameCh <- wm
+ sc.sendFrameWrite(wm)
return
}
sc.writeQueue = append(sc.writeQueue, wm) // TODO: proper scheduler
}
+// sendFrameWrite sends wm to the writeFrames goroutine. Only one frame can
+// be in flight at a time, so it panics if sc.writingFrame is already set.
+// If wm.endStream is set, the stream's state is advanced right before the
+// frame is sent; a stream that reaches stateClosed is removed from sc.streams.
+func (sc *serverConn) sendFrameWrite(wm frameWriteMsg) {
+ sc.serveG.check()
+ if sc.writingFrame {
+ panic("invariant")
+ }
+ sc.writingFrame = true
+ if wm.endStream {
+ st, ok := sc.streams[wm.streamID]
+ if ok { // streams absent from sc.streams get no state change
+ switch st.state { // states other than the two below are left unchanged
+ case stateOpen: // we are closing our side first
+ st.state = stateHalfClosedLocal
+ case stateHalfClosedRemote: // remote side already closed; both sides are now closed
+ st.state = stateClosed
+ delete(sc.streams, wm.streamID)
+ }
+ }
+ }
+ sc.writeFrameCh <- wm
+}
+
func (sc *serverConn) enqueueSettingsAck() {
sc.serveG.check()
// Fast path for common case:
@@ -469,8 +493,7 @@
// (because a SETTINGS frame changed our max frame size while
// a stream was open and writing) and cut it up into smaller
// bits.
- sc.writingFrame = true
- sc.writeFrameCh <- wm
+ sc.sendFrameWrite(wm)
}
func (sc *serverConn) goAway(code ErrCode) {
@@ -780,6 +803,12 @@
} else {
st.body.Close(io.EOF)
}
+ switch st.state {
+ case stateOpen:
+ st.state = stateHalfClosedRemote
+ case stateHalfClosedLocal:
+ st.state = stateClosed // NOTE(review): unlike the stateClosed transition in sendFrameWrite, the stream is not deleted from sc.streams here; confirm whether a delete belongs here too (the commit message says CLOSED streams are removed)
+ }
}
return nil
}
@@ -957,9 +986,10 @@
// write runs on the writeFrames goroutine.
write func(sc *serverConn, v interface{}) error
- v interface{} // passed to write
- cost uint32 // number of flow control bytes required
- streamID uint32 // used for prioritization
+ v interface{} // passed to write
+ cost uint32 // number of flow control bytes required
+ streamID uint32 // used for prioritization
+ endStream bool // streamID is being closed locally
// done, if non-nil, must be a buffered channel with space for
// 1 message and is sent the return value from write (or an
@@ -991,10 +1021,11 @@
errc = make(chan error, 1)
}
sc.writeFrame(frameWriteMsg{
- write: (*serverConn).writeHeadersFrame,
- v: req,
- streamID: req.streamID,
- done: errc,
+ write: (*serverConn).writeHeadersFrame,
+ v: req,
+ streamID: req.streamID,
+ done: errc,
+ endStream: req.endStream,
})
if errc != nil {
<-errc
@@ -1180,26 +1211,14 @@
type chunkWriter struct{ rws *responseWriterState }
-// chunkWriter.Write is called from bufio.Writer. Because bufio.Writer passes through large
-// writes, we break them up here if they're too big.
-func (cw chunkWriter) Write(p []byte) (n int, err error) {
- for len(p) > 0 {
- chunk := p
- if len(chunk) > handlerChunkWriteSize {
- chunk = chunk[:handlerChunkWriteSize]
- }
- _, err = cw.rws.writeChunk(chunk)
- if err != nil {
- return
- }
- n += len(chunk)
- p = p[len(chunk):]
- }
- return n, nil
-}
+func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } // writeChunk splits p into handlerChunkWriteSize pieces, so large writes passed through by bufio.Writer are handled there
-// writeChunk writes small (max 4k, or handlerChunkWriteSize) chunks.
-// It's also responsible for sending the HEADER response.
+// writeChunk writes chunks from the bufio.Writer. But because
+// bufio.Writer may bypass its chunking, sometimes p may be
+// arbitrarily large.
+//
+// writeChunk is also responsible (on the first chunk) for sending the
+// HEADER response.
func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
if !rws.wroteHeader {
rws.writeHeader(200)
@@ -1213,31 +1232,58 @@
if rws.snapHeader.Get("Content-Type") == "" {
ctype = http.DetectContentType(p)
}
+ endStream := rws.handlerDone && len(p) == 0
rws.sc.writeHeaders(headerWriteReq{
streamID: rws.streamID,
httpResCode: rws.status,
h: rws.snapHeader,
- endStream: rws.handlerDone && len(p) == 0,
+ endStream: endStream,
contentType: ctype,
contentLength: clen,
})
+ if endStream {
+ return
+ }
}
- if len(p) == 0 && !rws.handlerDone {
+ if len(p) == 0 {
+ if rws.handlerDone {
+ rws.curChunk = nil
+ rws.curChunkIsFinal = true
+ rws.sc.writeFrame(frameWriteMsg{ // empty DATA frame carrying END_STREAM; NOTE(review): no done channel, so unlike the chunk loop this does not wait for the write; confirm rws.curChunk/curChunkIsFinal are not modified again before the writer goroutine reads them
+ write: (*serverConn).writeDataFrame,
+ cost: 0,
+ streamID: rws.streamID,
+ endStream: true,
+ v: rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
+ })
+ }
return
}
- rws.curChunk = p
- rws.curChunkIsFinal = rws.handlerDone
+ for len(p) > 0 {
+ chunk := p
+ if len(chunk) > handlerChunkWriteSize {
+ chunk = chunk[:handlerChunkWriteSize]
+ }
+ p = p[len(chunk):]
+ rws.curChunk = chunk
+ rws.curChunkIsFinal = rws.handlerDone && len(p) == 0
- // TODO: await flow control tokens for both stream and conn
- rws.sc.writeFrame(frameWriteMsg{
- cost: uint32(len(p)),
- streamID: rws.streamID,
- write: (*serverConn).writeDataFrame,
- done: rws.chunkWrittenCh,
- v: rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
- })
- err = <-rws.chunkWrittenCh // block until it's written
- return len(p), err
+ // TODO: await flow control tokens for both stream and conn
+ rws.sc.writeFrame(frameWriteMsg{
+ write: (*serverConn).writeDataFrame,
+ cost: uint32(len(chunk)),
+ streamID: rws.streamID,
+ endStream: rws.curChunkIsFinal,
+ done: rws.chunkWrittenCh,
+ v: rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
+ })
+ err = <-rws.chunkWrittenCh // block until it's written
+ if err != nil {
+ break
+ }
+ n += len(chunk)
+ }
+ return
}
func (w *responseWriter) Flush() {