Add Framer.WriteRawFrame and a max read frame size tunable; reject large frames.

Also fix the server's GOAWAY handling along the way, and start work on
reducing the memory pinned by idle connections: the frame reader now
blocks reading only into a 9-byte frame-header buffer, which opens the
door to allocating the rest of the frame buffer lazily.
diff --git a/server.go b/server.go
index c33eb3a..77b6832 100644
--- a/server.go
+++ b/server.go
@@ -49,6 +49,18 @@
type Server struct {
// MaxStreams optionally ...
MaxStreams int
+
+ // MaxReadFrameSize optionally specifies the largest frame
+ // this server is willing to read. A valid value is between
+ // 16k and 16M, inclusive.
+ MaxReadFrameSize uint32
+}
+
+func (s *Server) maxReadFrameSize() uint32 {
+	if v := s.MaxReadFrameSize; v < minMaxFrameSize || v > maxFrameSize {
+		return defaultMaxReadFrameSize // configured value is out of range
+	}
+	return s.MaxReadFrameSize
 }
var testHookOnConn func() // for testing
@@ -90,11 +102,17 @@
var testHookGetServerConn func(*serverConn)
func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) {
+ // TODO: write to a (custom?) buffered writer that can
+ // alternate when it's in buffered mode.
+ fr := NewFramer(c, c)
+ fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
+
sc := &serverConn{
+ srv: srv,
hs: hs,
conn: c,
handler: h,
- framer: NewFramer(c, c), // TODO: write to a (custom?) buffered writer that can alternate when it's in buffered mode.
+ framer: fr,
streams: make(map[uint32]*stream),
readFrameCh: make(chan frameAndGate),
readFrameErrCh: make(chan error, 1), // must be buffered for 1
@@ -130,6 +148,7 @@
type serverConn struct {
// Immutable:
+ srv *Server
hs *http.Server
conn net.Conn
handler http.Handler
@@ -159,10 +178,14 @@
maxHeaderListSize uint32 // zero means unknown (default)
maxConcurrentStreams int64 // negative means no limit.
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
- sentGoAway bool
- req requestParam // non-zero while reading request headers
- writingFrame bool // sent on writeFrameCh but haven't heard back on wroteFrameCh yet
- writeQueue []frameWriteMsg // TODO: proper scheduler, not a queue
+ req requestParam // non-zero while reading request headers
+ writingFrame bool // sent on writeFrameCh but haven't heard back on wroteFrameCh yet
+ writeQueue []frameWriteMsg // TODO: proper scheduler, not a queue
+ inGoAway bool // we've started to or sent GOAWAY
+ needToSendGoAway bool // we need to schedule a GOAWAY frame write
+ goAwayCode ErrCode
+ shutdownTimerCh <-chan time.Time
+ shutdownTimer *time.Timer
// Owned by the writeFrames goroutine; use writeG.check():
headerWriteBuf bytes.Buffer
@@ -362,10 +385,18 @@
}
}
+func (sc *serverConn) stopShutdownTimer() {
+	sc.serveG.check()
+	if sc.shutdownTimer != nil { // armed by shutDownIn; nil until then
+		sc.shutdownTimer.Stop()
+	}
+}
+
func (sc *serverConn) serve() {
sc.serveG.check()
defer sc.conn.Close()
defer sc.closeAllStreamsOnConnClose()
+ defer sc.stopShutdownTimer()
defer close(sc.doneServing) // unblocks handlers trying to send
defer close(sc.writeFrameCh) // stop the writeFrames loop
@@ -382,7 +413,6 @@
go sc.writeFrames() // source closed in stopServing
settingsTimer := time.NewTimer(firstSettingsTimeout)
-
for {
select {
case wm := <-sc.wantWriteFrameCh:
@@ -391,6 +421,9 @@
sc.writingFrame = false
sc.scheduleFrameWrite()
case fg, ok := <-sc.readFrameCh:
+ if !ok {
+ sc.readFrameCh = nil
+ }
if !sc.processFrameFromReader(fg, ok) {
return
}
@@ -401,6 +434,9 @@
case <-settingsTimer.C:
sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
return
+ case <-sc.shutdownTimerCh:
+ sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
+ return
case fn := <-sc.testHookCh:
fn()
}
@@ -410,7 +446,7 @@
func (sc *serverConn) sendInitialSettings(_ interface{}) error {
sc.writeG.check()
return sc.framer.WriteSettings(
- Setting{SettingMaxFrameSize, 1 << 20},
+ Setting{SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
Setting{SettingMaxConcurrentStreams, 250}, // TODO: tunable?
/* TODO: more actual settings */
)
@@ -508,8 +544,8 @@
func (sc *serverConn) enqueueSettingsAck() {
sc.serveG.check()
- // Fast path for common case:
if !sc.writingFrame {
+ sc.needToSendSettingsAck = false
sc.writeFrameCh <- frameWriteMsg{write: (*serverConn).writeSettingsAck}
return
}
@@ -519,10 +555,24 @@
func (sc *serverConn) scheduleFrameWrite() {
sc.serveG.check()
if sc.writingFrame {
- panic("invariant")
+ return
+ }
+ if sc.needToSendGoAway {
+ sc.needToSendGoAway = false
+ sc.sendFrameWrite(frameWriteMsg{
+ write: (*serverConn).writeGoAwayFrame,
+ v: &goAwayParams{
+ maxStreamID: sc.maxStreamID,
+ code: sc.goAwayCode,
+ },
+ })
+ return
+ }
+ if sc.inGoAway {
+ // No more frames after we've sent GOAWAY.
+ return
}
if sc.needToSendSettingsAck {
- sc.needToSendSettingsAck = false
sc.enqueueSettingsAck()
return
}
@@ -546,18 +596,25 @@
func (sc *serverConn) goAway(code ErrCode) {
sc.serveG.check()
- if sc.sentGoAway {
+ if sc.inGoAway {
return
}
- sc.sentGoAway = true
- // TODO: set a timer to see if they're gone at some point?
- sc.enqueueFrameWrite(frameWriteMsg{
- write: (*serverConn).writeGoAwayFrame,
- v: &goAwayParams{
- maxStreamID: sc.maxStreamID,
- code: code,
- },
- })
+ if code != ErrCodeNo {
+ sc.shutDownIn(250 * time.Millisecond)
+ } else {
+ // TODO: configurable
+ sc.shutDownIn(1 * time.Second)
+ }
+ sc.inGoAway = true
+ sc.needToSendGoAway = true
+ sc.goAwayCode = code
+ sc.scheduleFrameWrite()
+}
+
+func (sc *serverConn) shutDownIn(d time.Duration) {
+	sc.stopShutdownTimer() // asserts serveG; stops any previously armed timer
+	sc.shutdownTimer = time.NewTimer(d)
+	sc.shutdownTimerCh = sc.shutdownTimer.C // serve() returns when this fires
 }
type goAwayParams struct {
@@ -568,7 +625,15 @@
func (sc *serverConn) writeGoAwayFrame(v interface{}) error {
sc.writeG.check()
p := v.(*goAwayParams)
- return sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
+ err := sc.framer.WriteGoAway(p.maxStreamID, p.code, nil)
+ if p.code != 0 {
+ // TODO: flush any buffer, if we add a buffering writing
+ // Sleep a bit to give the peer a bit of time to read the
+ // GOAWAY before potentially getting a TCP RST packet:
+ time.Sleep(50 * time.Millisecond)
+ sc.conn.Close()
+ }
+ return err
}
func (sc *serverConn) resetStreamInLoop(se StreamError) {
@@ -607,6 +672,10 @@
sc.serveG.check()
if !fgValid {
err := <-sc.readFrameErrCh
+ if err == ErrFrameTooLarge {
+ sc.goAway(ErrCodeFrameSize)
+ return true // goAway will close the loop
+ }
if err != io.EOF {
errstr := err.Error()
if !strings.Contains(errstr, "use of closed network connection") {
@@ -634,7 +703,9 @@
sc.goAway(ErrCodeFlowControl)
return true
case ConnectionError:
- sc.logf("disconnecting; %v", ev)
+ sc.logf("%v: %v", sc.conn.RemoteAddr(), ev)
+ sc.goAway(ErrCode(ev))
+ return true // goAway will handle shutdown
default:
sc.logf("Disconnection due to other error: %v", err)
}
@@ -676,7 +747,7 @@
case *RSTStreamFrame:
return sc.processResetStream(f)
default:
- log.Printf("Ignoring unknown frame %#v", f)
+ log.Printf("Ignoring frame: %v", f.Header())
return nil
}
}
@@ -892,7 +963,7 @@
func (sc *serverConn) processHeaders(f *HeadersFrame) error {
sc.serveG.check()
id := f.Header().StreamID
- if sc.sentGoAway {
+ if sc.inGoAway {
// Ignore.
return nil
}