http2: remove extra goroutine stack from awaitGracefulShutdown
This is a better fix that https://golang.org/cl/43455. Instead of
creating a separate goroutine to wait for the global shutdown channel,
we reuse the new serverMsgCh, which was added in a prior CL.
We also use the new net/http.Server.RegisterOnShutdown method to
register a shutdown callback for each http2.Server.
Updates golang/go#20302
Updates golang/go#18471
Change-Id: Icf29d5e4f65b3779d1fb4ea92924e4fb6bdadb2a
Reviewed-on: https://go-review.googlesource.com/43230
Run-TryBot: Tom Bergan <tombergan@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/go19.go b/http2/go19.go
new file mode 100644
index 0000000..38124ba
--- /dev/null
+++ b/http2/go19.go
@@ -0,0 +1,16 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+
+package http2
+
+import (
+ "net/http"
+)
+
+func configureServer19(s *http.Server, conf *Server) error {
+ s.RegisterOnShutdown(conf.state.startGracefulShutdown)
+ return nil
+}
diff --git a/http2/go19_test.go b/http2/go19_test.go
new file mode 100644
index 0000000..1675d24
--- /dev/null
+++ b/http2/go19_test.go
@@ -0,0 +1,60 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+
+package http2
+
+import (
+ "context"
+ "net/http"
+ "reflect"
+ "testing"
+ "time"
+)
+
+func TestServerGracefulShutdown(t *testing.T) {
+ var st *serverTester
+ handlerDone := make(chan struct{})
+ st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ defer close(handlerDone)
+ go st.ts.Config.Shutdown(context.Background())
+
+ ga := st.wantGoAway()
+ if ga.ErrCode != ErrCodeNo {
+ t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
+ }
+ if ga.LastStreamID != 1 {
+ t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
+ }
+
+ w.Header().Set("x-foo", "bar")
+ })
+ defer st.Close()
+
+ st.greet()
+ st.bodylessReq1()
+
+ select {
+ case <-handlerDone:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("server did not shutdown?")
+ }
+ hf := st.wantHeaders()
+ goth := st.decodeHeader(hf.HeaderBlockFragment())
+ wanth := [][2]string{
+ {":status", "200"},
+ {"x-foo", "bar"},
+ {"content-type", "text/plain; charset=utf-8"},
+ {"content-length", "0"},
+ }
+ if !reflect.DeepEqual(goth, wanth) {
+ t.Errorf("Got headers %v; want %v", goth, wanth)
+ }
+
+ n, err := st.cc.Read([]byte{0})
+ if n != 0 || err == nil {
+ t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
+ }
+}
diff --git a/http2/not_go19.go b/http2/not_go19.go
new file mode 100644
index 0000000..5ae0772
--- /dev/null
+++ b/http2/not_go19.go
@@ -0,0 +1,16 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.9
+
+package http2
+
+import (
+ "net/http"
+)
+
+func configureServer19(s *http.Server, conf *Server) error {
+ // not supported prior to go1.9
+ return nil
+}
diff --git a/http2/server.go b/http2/server.go
index 3175a08..7367b31 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -126,6 +126,11 @@
// NewWriteScheduler constructs a write scheduler for a connection.
// If nil, a default scheduler is chosen.
NewWriteScheduler func() WriteScheduler
+
+ // Internal state. This is a pointer (rather than embedded directly)
+ // so that we don't embed a Mutex in this struct, which will make the
+ // struct non-copyable, which might break some callers.
+ state *serverInternalState
}
func (s *Server) initialConnRecvWindowSize() int32 {
@@ -156,6 +161,40 @@
return defaultMaxStreams
}
+type serverInternalState struct {
+ mu sync.Mutex
+ activeConns map[*serverConn]struct{}
+}
+
+func (s *serverInternalState) registerConn(sc *serverConn) {
+ if s == nil {
+ return // if the Server was used without calling ConfigureServer
+ }
+ s.mu.Lock()
+ s.activeConns[sc] = struct{}{}
+ s.mu.Unlock()
+}
+
+func (s *serverInternalState) unregisterConn(sc *serverConn) {
+ if s == nil {
+ return // if the Server was used without calling ConfigureServer
+ }
+ s.mu.Lock()
+ delete(s.activeConns, sc)
+ s.mu.Unlock()
+}
+
+func (s *serverInternalState) startGracefulShutdown() {
+ if s == nil {
+ return // if the Server was used without calling ConfigureServer
+ }
+ s.mu.Lock()
+ for sc := range s.activeConns {
+ sc.startGracefulShutdown()
+ }
+ s.mu.Unlock()
+}
+
// ConfigureServer adds HTTP/2 support to a net/http Server.
//
// The configuration conf may be nil.
@@ -168,9 +207,13 @@
if conf == nil {
conf = new(Server)
}
+ conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})}
if err := configureServer18(s, conf); err != nil {
return err
}
+ if err := configureServer19(s, conf); err != nil {
+ return err
+ }
if s.TLSConfig == nil {
s.TLSConfig = new(tls.Config)
@@ -305,6 +348,9 @@
pushEnabled: true,
}
+ s.state.registerConn(sc)
+ defer s.state.unregisterConn(sc)
+
// The net/http package sets the write deadline from the
// http.Server.WriteTimeout during the TLS handshake, but then
// passes the connection off to us with the deadline already set.
@@ -445,6 +491,9 @@
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
hpackEncoder *hpack.Encoder
+
+ // Used by startGracefulShutdown.
+ shutdownOnce sync.Once
}
func (sc *serverConn) maxHeaderListSize() uint32 {
@@ -749,15 +798,6 @@
defer sc.idleTimer.Stop()
}
- var gracefulShutdownCh chan struct{}
- if sc.hs != nil {
- ch := h1ServerShutdownChan(sc.hs)
- if ch != nil {
- gracefulShutdownCh = make(chan struct{})
- go sc.awaitGracefulShutdown(ch, gracefulShutdownCh)
- }
- }
-
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
@@ -786,14 +826,11 @@
}
case m := <-sc.bodyReadCh:
sc.noteBodyRead(m.st, m.n)
- case <-gracefulShutdownCh:
- gracefulShutdownCh = nil
- sc.startGracefulShutdown()
case msg := <-sc.serveMsgCh:
switch v := msg.(type) {
case func(int):
v(loopNum) // for testing
- case *timerMessage:
+ case *serverMessage:
switch v {
case settingsTimerMsg:
sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
@@ -804,6 +841,8 @@
case shutdownTimerMsg:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
+ case gracefulShutdownMsg:
+ sc.startGracefulShutdownInternal()
default:
panic("unknown timer")
}
@@ -828,13 +867,14 @@
}
}
-type timerMessage int
+type serverMessage int
-// Timeout message values sent to serveMsgCh.
+// Message values sent to serveMsgCh.
var (
- settingsTimerMsg = new(timerMessage)
- idleTimerMsg = new(timerMessage)
- shutdownTimerMsg = new(timerMessage)
+ settingsTimerMsg = new(serverMessage)
+ idleTimerMsg = new(serverMessage)
+ shutdownTimerMsg = new(serverMessage)
+ gracefulShutdownMsg = new(serverMessage)
)
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
@@ -1166,10 +1206,19 @@
sc.inFrameScheduleLoop = false
}
-// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the
-// client we're gracefully shutting down. The connection isn't closed
-// until all current streams are done.
+// startGracefulShutdown gracefully shuts down a connection. This
+// sends GOAWAY with ErrCodeNo to tell the client we're gracefully
+// shutting down. The connection isn't closed until all current
+// streams are done.
+//
+// startGracefulShutdown returns immediately; it does not wait until
+// the connection has shut down.
func (sc *serverConn) startGracefulShutdown() {
+ sc.serveG.checkNotOn() // NOT
+ sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) })
+}
+
+func (sc *serverConn) startGracefulShutdownInternal() {
sc.goAwayIn(ErrCodeNo, 0)
}
@@ -1399,7 +1448,7 @@
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if h1ServerKeepAlivesDisabled(sc.hs) {
- sc.startGracefulShutdown()
+ sc.startGracefulShutdownInternal()
}
}
if p := st.body; p != nil {
@@ -1586,7 +1635,7 @@
} else {
sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
}
- sc.startGracefulShutdown()
+ sc.startGracefulShutdownInternal()
// http://tools.ietf.org/html/rfc7540#section-6.8
// We should not create any new streams, which means we should disable push.
sc.pushEnabled = false
@@ -2653,7 +2702,7 @@
// A server that is unable to establish a new stream identifier can send a GOAWAY
// frame so that the client is forced to open a new connection for new streams.
if sc.maxPushPromiseID+2 >= 1<<31 {
- sc.startGracefulShutdown()
+ sc.startGracefulShutdownInternal()
return 0, ErrPushLimitReached
}
sc.maxPushPromiseID += 2
@@ -2778,31 +2827,6 @@
"Www-Authenticate": true,
}
-// h1ServerShutdownChan returns a channel that will be closed when the
-// provided *http.Server wants to shut down.
-//
-// This is a somewhat hacky way to get at http1 innards. It works
-// when the http2 code is bundled into the net/http package in the
-// standard library. The alternatives ended up making the cmd/go tool
-// depend on http Servers. This is the lightest option for now.
-// This is tested via the TestServeShutdown* tests in net/http.
-func h1ServerShutdownChan(hs *http.Server) <-chan struct{} {
- if fn := testh1ServerShutdownChan; fn != nil {
- return fn(hs)
- }
- var x interface{} = hs
- type I interface {
- getDoneChan() <-chan struct{}
- }
- if hs, ok := x.(I); ok {
- return hs.getDoneChan()
- }
- return nil
-}
-
-// optional test hook for h1ServerShutdownChan.
-var testh1ServerShutdownChan func(hs *http.Server) <-chan struct{}
-
// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives
// disabled. See comments on h1ServerShutdownChan above for why
// the code is written this way.
diff --git a/http2/server_test.go b/http2/server_test.go
index 5cb2490..638d2a4 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -3685,48 +3685,3 @@
<-done
}
}
-
-func TestServerGracefulShutdown(t *testing.T) {
- shutdownCh := make(chan struct{})
- defer func() { testh1ServerShutdownChan = nil }()
- testh1ServerShutdownChan = func(*http.Server) <-chan struct{} { return shutdownCh }
-
- var st *serverTester
- handlerDone := make(chan struct{})
- st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
- defer close(handlerDone)
- close(shutdownCh)
-
- ga := st.wantGoAway()
- if ga.ErrCode != ErrCodeNo {
- t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
- }
- if ga.LastStreamID != 1 {
- t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
- }
-
- w.Header().Set("x-foo", "bar")
- })
- defer st.Close()
-
- st.greet()
- st.bodylessReq1()
-
- <-handlerDone
- hf := st.wantHeaders()
- goth := st.decodeHeader(hf.HeaderBlockFragment())
- wanth := [][2]string{
- {":status", "200"},
- {"x-foo", "bar"},
- {"content-type", "text/plain; charset=utf-8"},
- {"content-length", "0"},
- }
- if !reflect.DeepEqual(goth, wanth) {
- t.Errorf("Got headers %v; want %v", goth, wanth)
- }
-
- n, err := st.cc.Read([]byte{0})
- if n != 0 || err == nil {
- t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
- }
-}