Add http.CloseNotifier support, with tests, and fix the bugs those tests uncovered
diff --git a/server.go b/server.go
index a84f0f0..30f5758 100644
--- a/server.go
+++ b/server.go
@@ -97,7 +97,7 @@
readFrameErrCh: make(chan error, 1), // must be buffered for 1
wantWriteFrameCh: make(chan frameWriteMsg, 8),
writeFrameCh: make(chan frameWriteMsg, 1), // may be 0 or 1, but more is useless. (max 1 in flight)
- wroteFrameCh: make(chan struct{}, 1),
+ wroteFrameCh: make(chan struct{}, 1), // must stay buffered: sendFrameWrite sends on it from the serve goroutine itself to fake a completed write when skipping a frame for a reset stream; unbuffered (0) would deadlock there
flow: newFlow(initialWindowSize),
doneServing: make(chan struct{}),
maxWriteFrameSize: initialMaxFrameSize,
@@ -180,14 +180,27 @@
invalidHeader bool // an invalid header was seen
}
+// stream represents an HTTP/2 stream. This is the minimal metadata needed by
+// the serve goroutine. Most of the actual stream state is owned by
+// the http.Handler's goroutine in the responseWriter. Because the
+// responseWriter's responseWriterState is recycled at the end of a
+// handler, this struct intentionally has no pointer to the
+// *responseWriter{,State} itself, as the Handler ending nils out the
+// responseWriter's state field.
type stream struct {
- id uint32
- state streamState // owned by serverConn's serve loop
- flow *flow // limits writing from Handler to client
- body *pipe // non-nil if expecting DATA frames
+ // immutable:
+ id uint32
+ conn *serverConn // the connection this stream belongs to
+ flow *flow // limits writing from Handler to client
+ body *pipe // non-nil if expecting DATA frames
+ cw closeWaiter // closed when the stream transitions to the closed state (see closeStream)
+ // owned by serverConn's serve loop:
+ state streamState
bodyBytes int64 // body bytes seen so far
declBodyBytes int64 // or -1 if undeclared
+ sentReset bool // we sent RST_STREAM; only true once detached from streams map
+ gotReset bool // peer sent RST_STREAM; only true once detached from streams map
}
func (sc *serverConn) state(streamID uint32) streamState {
@@ -337,12 +350,13 @@
}
}
+var errClientDisconnected = errors.New("client disconnected") // passed by stopServing to closeStream, which hands it to each open stream's body pipe
+
func (sc *serverConn) stopServing() {
sc.serveG.check()
close(sc.writeFrameCh) // stop the writeFrames loop
- err := errors.New("client disconnected")
- for id := range sc.streams {
- sc.closeStream(id, err)
+ for _, st := range sc.streams { // closeStream deletes from sc.streams; deleting during range is safe in Go
+ sc.closeStream(st, errClientDisconnected)
}
}
@@ -449,17 +463,45 @@
if sc.writingFrame {
panic("invariant")
}
+
+ st := wm.stream
+ if st != nil {
+ switch st.state {
+ case stateHalfClosedLocal:
+ panic("internal error: attempt to send frame on half-closed-local stream")
+ case stateClosed:
+ if st.sentReset || st.gotReset {
+ // Skip this frame: the stream was reset by us or by the peer.
+ // A waiter on wm.done (e.g. a Handler writing headers or data)
+ // must still be released or it blocks forever; done is required
+ // to be buffered with space for a value, so this send won't block.
+ if wm.done != nil {
+ wm.done <- errors.New("stream closed")
+ }
+ // Fake a completed frame write so the serve loop reschedules.
+ // Mark a write as in flight first, exactly like a real write, so
+ // no other frame is sent (and no second fake completion is queued
+ // on the 1-buffered wroteFrameCh, deadlocking this goroutine) until
+ // this one is consumed. NOTE(review): assumes the wroteFrameCh
+ // receive in the serve loop clears writingFrame; confirm.
+ sc.writingFrame = true
+ sc.wroteFrameCh <- struct{}{}
+ return
+ }
+ panic("internal error: attempt to send a frame on a closed stream")
+ }
+ }
+
sc.writingFrame = true
if wm.endStream {
- st, ok := sc.streams[wm.streamID]
- if ok {
- switch st.state {
- case stateOpen:
- st.state = stateHalfClosedLocal
- case stateHalfClosedRemote:
- st.state = stateClosed
- delete(sc.streams, wm.streamID)
- }
+ if st == nil {
+ panic("nil stream with endStream set")
+ }
+ switch st.state {
+ case stateOpen:
+ st.state = stateHalfClosedLocal
+ case stateHalfClosedRemote:
+ sc.closeStream(st, nil)
}
}
sc.writeFrameCh <- wm
@@ -530,20 +559,29 @@
return sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
}
-func (sc *serverConn) resetStreamInLoop(se StreamError) error {
+// resetStreamInLoop enqueues an RST_STREAM frame for se and, if the
+// stream is still tracked in sc.streams, marks it as reset by us and
+// closes it. A StreamError may refer to a stream that is no longer in
+// the map (RFC 7540 section 5.1 requires STREAM_CLOSED stream errors for
+// frames on closed streams); the peer still gets the RST_STREAM then.
+func (sc *serverConn) resetStreamInLoop(se StreamError) {
sc.serveG.check()
- delete(sc.streams, se.streamID)
+ st, ok := sc.streams[se.StreamID]
sc.enqueueFrameWrite(frameWriteMsg{
write: (*serverConn).writeRSTStreamFrame,
v: &se,
})
- return nil
+ if !ok {
+ return
+ }
+ st.sentReset = true // makes sendFrameWrite skip frames still queued for this stream
+ sc.closeStream(st, se)
}
func (sc *serverConn) writeRSTStreamFrame(v interface{}) error {
sc.writeG.check()
se := v.(*StreamError)
- return sc.framer.WriteRSTStream(se.streamID, se.code)
+ return sc.framer.WriteRSTStream(se.StreamID, se.Code)
}
func (sc *serverConn) curHeaderStreamID() uint32 {
@@ -583,10 +616,7 @@
switch ev := err.(type) {
case StreamError:
- if err := sc.resetStreamInLoop(ev); err != nil {
- sc.logf("Error writing RSTSTream: %v", err)
- return false
- }
+ sc.resetStreamInLoop(ev) // only enqueues RST_STREAM and closes the stream, so there is no write error to handle here
return true
case goAwayFlowError:
sc.goAway(ErrCodeFlowControl)
@@ -693,22 +723,34 @@
func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
sc.serveG.check()
- sc.closeStream(f.StreamID, StreamError{f.StreamID, f.ErrCode})
+ if sc.state(f.StreamID) == stateIdle {
+ // RFC 7540 section 6.4: "RST_STREAM frames MUST NOT be sent for a
+ // stream in the "idle" state. If a RST_STREAM frame
+ // identifying an idle stream is received, the
+ // recipient MUST treat this as a connection error
+ // (Section 5.4.1) of type PROTOCOL_ERROR."
+ return ConnectionError(ErrCodeProtocol)
+ }
+ st, ok := sc.streams[f.StreamID]
+ if ok { // not in the map (and not idle): presumably already closed, nothing to do
+ st.gotReset = true // makes sendFrameWrite skip frames still queued for this stream
+ sc.closeStream(st, StreamError{f.StreamID, f.ErrCode})
+ }
return nil
}
-func (sc *serverConn) closeStream(streamID uint32, err error) {
+func (sc *serverConn) closeStream(st *stream, err error) {
sc.serveG.check()
- st, ok := sc.streams[streamID]
- if !ok {
- return
+ if st.state == stateIdle || st.state == stateClosed { // callers may only close streams that are still live
+ panic("invariant")
}
- st.state = stateClosed // kinda useless
- delete(sc.streams, streamID)
+ st.state = stateClosed
+ delete(sc.streams, st.id)
st.flow.close()
if p := st.body; p != nil {
p.Close(err)
}
+ st.cw.Close() // wakes the Handler's CloseNotify goroutine (if any) so it sends on its channel
}
func (sc *serverConn) processSettings(f *SettingsFrame) error {
@@ -818,7 +860,7 @@
}
st.bodyBytes += int64(len(data))
}
- if f.Header().Flags.Has(FlagDataEndStream) {
+ if f.StreamEnded() { // END_STREAM flag set on this DATA frame
if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
st.body.Close(fmt.Errorf("Request declared a Content-Length of %d but only wrote %d bytes",
st.declBodyBytes, st.bodyBytes))
@@ -857,11 +899,13 @@
sc.maxStreamID = id
}
st := &stream{
+ conn: sc, // lets handler-side code reach the conn through the stream
id: id,
state: stateOpen,
flow: newFlow(sc.initialWindowSize),
}
- if f.Header().Flags.Has(FlagHeadersEndStream) {
+ st.cw.Init() // make Cond use its Mutex, without heap-promoting them separately
+ if f.StreamEnded() { // END_STREAM on HEADERS: the client has nothing more to send
st.state = stateHalfClosedRemote
}
sc.streams[id] = st
@@ -938,7 +982,7 @@
bodyOpen := rp.stream.state == stateOpen
body := &requestBody{
sc: sc,
- streamID: rp.stream.id,
+ stream: rp.stream, // used by Read for 100-continue and window updates
needsContinue: needsContinue,
}
url, err := url.ParseRequestURI(rp.path)
@@ -977,8 +1021,7 @@
*rws = responseWriterState{} // zero all the fields
rws.bw = bwSave
rws.bw.Reset(chunkWriter{rws})
- rws.sc = sc
- rws.streamID = rp.stream.id
+ rws.stream = rp.stream // the conn is reached via rws.stream.conn
rws.req = req
rws.body = body
rws.chunkWrittenCh = make(chan error, 1)
@@ -1010,7 +1053,7 @@
v interface{} // passed to write
cost uint32 // number of flow control bytes required
- streamID uint32 // used for prioritization
+ stream *stream // used for prioritization and by sendFrameWrite to check stream state; may be nil (e.g. the RST_STREAM write)
endStream bool // streamID is being closed locally
// done, if non-nil, must be a buffered channel with space for
@@ -1021,7 +1064,7 @@
// headerWriteReq is a request to write an HTTP response header from a server Handler.
type headerWriteReq struct {
- streamID uint32
+ stream *stream // only stream.id is read by the frame writer
httpResCode int
h http.Header // may be nil
endStream bool
@@ -1033,7 +1076,7 @@
// called from handler goroutines.
// h may be nil.
func (sc *serverConn) writeHeaders(req headerWriteReq) {
- sc.serveG.checkNotOn()
+ sc.serveG.checkNotOn() // must NOT run on the serve goroutine; only handler goroutines call this
var errc chan error
if req.h != nil {
// If there's a header map (which we don't own), so we have to block on
@@ -1045,7 +1088,7 @@
sc.writeFrame(frameWriteMsg{
write: (*serverConn).writeHeadersFrame,
v: req,
- streamID: req.streamID,
+ stream: req.stream,
done: errc,
endStream: req.endStream,
})
@@ -1083,7 +1126,7 @@
panic("TODO")
}
return sc.framer.WriteHeaders(HeadersFrameParam{
- StreamID: req.streamID,
+ StreamID: req.stream.id, // id is immutable, so safe to read from the writer goroutine
BlockFragment: headerBlock,
EndStream: req.endStream,
EndHeaders: true, // no continuation yet
@@ -1091,23 +1134,22 @@
}
// called from handler goroutines.
-// h may be nil.
-func (sc *serverConn) write100ContinueHeaders(streamID uint32) {
- sc.serveG.checkNotOn()
+func (sc *serverConn) write100ContinueHeaders(st *stream) {
+ sc.serveG.checkNotOn() // must NOT run on the serve goroutine
sc.writeFrame(frameWriteMsg{
- write: (*serverConn).write100ContinueHeadersFrame,
- v: &streamID,
- streamID: streamID,
+ write: (*serverConn).write100ContinueHeadersFrame,
+ v: st,
+ stream: st,
})
}
func (sc *serverConn) write100ContinueHeadersFrame(v interface{}) error {
sc.writeG.check()
- streamID := *(v.(*uint32))
+ st := v.(*stream) // only the immutable st.id is read here, on the writer goroutine
sc.headerWriteBuf.Reset()
sc.hpackEncoder.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
return sc.framer.WriteHeaders(HeadersFrameParam{
- StreamID: streamID,
+ StreamID: st.id,
BlockFragment: sc.headerWriteBuf.Bytes(),
EndStream: false,
EndHeaders: true,
@@ -1117,30 +1159,33 @@
func (sc *serverConn) writeDataFrame(v interface{}) error {
sc.writeG.check()
rws := v.(*responseWriterState)
- return sc.framer.WriteData(rws.streamID, rws.curChunkIsFinal, rws.curChunk)
+ return sc.framer.WriteData(rws.stream.id, rws.curChunkIsFinal, rws.curChunk)
}
type windowUpdateReq struct {
- streamID uint32
- n uint32
+ stream *stream // only stream.id is read, on the writer goroutine
+ n uint32
}
// called from handler goroutines
-func (sc *serverConn) sendWindowUpdate(streamID uint32, n int) {
+func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
+ if st == nil { // fail fast: a nil stream would otherwise panic later in sendWindowUpdateInLoop (wu.stream.id)
+ panic("no stream")
+ }
const maxUint32 = 2147483647
for n >= maxUint32 {
sc.writeFrame(frameWriteMsg{
- write: (*serverConn).sendWindowUpdateInLoop,
- v: windowUpdateReq{streamID, maxUint32},
- streamID: streamID,
+ write: (*serverConn).sendWindowUpdateInLoop,
+ v: windowUpdateReq{st, maxUint32}, // "maxUint32" is really 2^31-1, the largest legal WINDOW_UPDATE increment
+ stream: st, // NOTE(review): if st was reset, sendFrameWrite skips this whole write, including the connection-level WINDOW_UPDATE
})
n -= maxUint32
}
if n > 0 {
sc.writeFrame(frameWriteMsg{
- write: (*serverConn).sendWindowUpdateInLoop,
- v: windowUpdateReq{streamID, uint32(n)},
- streamID: streamID,
+ write: (*serverConn).sendWindowUpdateInLoop,
+ v: windowUpdateReq{st, uint32(n)},
+ stream: st,
})
}
}
@@ -1151,7 +1196,7 @@
if err := sc.framer.WriteWindowUpdate(0, wu.n); err != nil {
return err
}
- if err := sc.framer.WriteWindowUpdate(wu.streamID, wu.n); err != nil {
+ if err := sc.framer.WriteWindowUpdate(wu.stream.id, wu.n); err != nil {
return err
}
return nil
@@ -1159,7 +1204,7 @@
type requestBody struct {
sc *serverConn
- streamID uint32
+ stream *stream // target of the 100-continue and WINDOW_UPDATE writes made by Read
closed bool
pipe *pipe // non-nil if we have a HTTP entity message body
needsContinue bool // need to send a 100-continue
@@ -1178,14 +1223,14 @@
func (b *requestBody) Read(p []byte) (n int, err error) {
if b.needsContinue {
b.needsContinue = false
- b.sc.write100ContinueHeaders(b.streamID)
+ b.sc.write100ContinueHeaders(b.stream)
}
if b.pipe == nil {
return 0, io.EOF
}
n, err = b.pipe.Read(p)
if n > 0 {
- b.sc.sendWindowUpdate(b.streamID, n)
+ b.sc.sendWindowUpdate(b.stream, n) // credits n bytes back to the client (connection and stream WINDOW_UPDATE)
// TODO: tell b.sc to send back 'n' flow control quota credits to the sender
}
return
@@ -1203,17 +1248,17 @@
// Optional http.ResponseWriter interfaces implemented.
var (
- _ http.Flusher = (*responseWriter)(nil)
- _ stringWriter = (*responseWriter)(nil)
+ _ http.CloseNotifier = (*responseWriter)(nil)
+ _ http.Flusher = (*responseWriter)(nil)
+ _ stringWriter = (*responseWriter)(nil)
// TODO: hijacker for websockets?
)
type responseWriterState struct {
// immutable within a request:
- sc *serverConn
- streamID uint32
- req *http.Request
- body *requestBody // to close at end of request, if DATA frames didn't
+ stream *stream // also gives access to the conn via stream.conn
+ req *http.Request
+ body *requestBody // to close at end of request, if DATA frames didn't
// TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc
bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState}
@@ -1229,6 +1274,9 @@
curChunk []byte // current chunk we're writing
curChunkIsFinal bool
chunkWrittenCh chan error
+
+ closeNotifierMu sync.Mutex // guards closeNotifierCh
+ closeNotifierCh chan bool // nil until the first CloseNotify call
}
type chunkWriter struct{ rws *responseWriterState }
@@ -1255,8 +1303,8 @@
ctype = http.DetectContentType(p)
}
endStream := rws.handlerDone && len(p) == 0
- rws.sc.writeHeaders(headerWriteReq{
- streamID: rws.streamID,
+ rws.stream.conn.writeHeaders(headerWriteReq{
+ stream: rws.stream,
httpResCode: rws.status,
h: rws.snapHeader,
endStream: endStream,
@@ -1271,10 +1319,10 @@
if rws.handlerDone {
rws.curChunk = nil
rws.curChunkIsFinal = true
- rws.sc.writeFrame(frameWriteMsg{
+ rws.stream.conn.writeFrame(frameWriteMsg{
write: (*serverConn).writeDataFrame,
cost: 0,
- streamID: rws.streamID,
+ stream: rws.stream, // writeDataFrame also reads rws.stream.id
endStream: true,
v: rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
})
@@ -1291,10 +1339,10 @@
rws.curChunkIsFinal = rws.handlerDone && len(p) == 0
// TODO: await flow control tokens for both stream and conn
- rws.sc.writeFrame(frameWriteMsg{
+ rws.stream.conn.writeFrame(frameWriteMsg{
write: (*serverConn).writeDataFrame,
cost: uint32(len(chunk)),
- streamID: rws.streamID,
+ stream: rws.stream, // writeDataFrame also reads rws.stream.id
endStream: rws.curChunkIsFinal,
done: rws.chunkWrittenCh,
v: rws, // writeDataInLoop uses only rws.curChunk and rws.curChunkIsFinal
@@ -1327,6 +1375,33 @@
}
}
+// CloseNotify implements http.CloseNotifier. The returned channel
+// receives a single true value once the stream is closed, whether
+// normally, by a peer or local RST_STREAM, or because the connection
+// stopped serving. It must only be called while the Handler is running.
+func (w *responseWriter) CloseNotify() <-chan bool {
+ rws := w.rws
+ if rws == nil {
+ panic("CloseNotify called after Handler finished")
+ }
+ rws.closeNotifierMu.Lock()
+ ch := rws.closeNotifierCh
+ if ch == nil {
+ ch = make(chan bool, 1)
+ rws.closeNotifierCh = ch
+ // Capture the stream now rather than reading rws.stream inside the
+ // goroutine: rws is put back in responseWriterStatePool (and reused
+ // for another request) once the Handler finishes.
+ st := rws.stream
+ go func() {
+ st.cw.Wait() // wait for close
+ ch <- true
+ }()
+ }
+ rws.closeNotifierMu.Unlock()
+ return ch
+}
+
func (w *responseWriter) Header() http.Header {
rws := w.rws
if rws == nil {
@@ -1400,7 +1475,6 @@
}
rws.handlerDone = true
w.Flush()
-
w.rws = nil
responseWriterStatePool.Put(rws)
}